This rule raises an issue when super() is called inside methods decorated with TorchScript decorators such as @torch.jit.script_method, or inside classes that are likely to be compiled with TorchScript.

Why is this an issue?

TorchScript has limited support for Python’s super() mechanism, which can lead to compilation errors when converting PyTorch models for deployment.

TorchScript is PyTorch’s way to create serializable and optimizable models from PyTorch code. It allows you to run models independently from Python, which is essential for production deployment, mobile applications, and performance optimization.

However, TorchScript operates with a subset of Python’s features. The super() function relies on Python’s method resolution order (MRO) and dynamic attribute lookup, which are not fully supported in TorchScript’s static compilation environment.
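To see why this lookup is dynamic, consider how super() behaves in plain Python. The toy Base and Derived classes below are illustrative only (no PyTorch involved): the method that super() dispatches to is found by walking the class's MRO at runtime, which is exactly the kind of dynamic resolution a static compiler cannot rely on.

```python
class Base:
    def forward(self, x):
        return x * 2

class Derived(Base):
    def forward(self, x):
        # super() walks Derived.__mro__ at runtime to locate Base.forward;
        # a static compiler like TorchScript cannot perform this lookup.
        return super().forward(x) + 1

# The MRO is a runtime attribute of the class:
print([cls.__name__ for cls in Derived.__mro__])  # ['Derived', 'Base', 'object']
print(Derived().forward(3))  # 7
```

Because the target of super().forward is only determined by inspecting the MRO when the call executes, TorchScript's ahead-of-time compilation has no reliable way to resolve it.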

When TorchScript encounters super() calls, it may fail to properly resolve the method calls during compilation, resulting in runtime errors or unexpected behavior. This is particularly problematic in forward() methods of neural network modules, where inheritance is commonly used.

What is the potential impact?

Using super() calls in TorchScript methods can cause:

- Compilation failures when the model is converted to TorchScript.
- Runtime errors or unexpected behavior when the method call is resolved incorrectly.
- Blocked deployment, since the model cannot be reliably serialized for production, mobile, or performance-critical environments.

How to fix in PyTorch?

Replace super() calls with direct method calls or refactor the inheritance structure to avoid super() usage in TorchScript methods.

Non-compliant code example

import torch

class MyModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        return super().forward(x)  # Noncompliant

Compliant code example

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        # Avoid super() in TorchScript methods
        return self.process(x)

    def process(self, x):
        return x

For complex inheritance scenarios, explicitly call parent class methods by name instead of using super().

Non-compliant code example

import torch

class BaseModel(torch.jit.ScriptModule):
    def forward(self, x):
        return x * 2

class DerivedModel(BaseModel):
    @torch.jit.script_method
    def forward(self, x):
        result = super().forward(x)  # Noncompliant
        return result + 1

Compliant code example

import torch

class BaseModel(torch.jit.ScriptModule):
    def forward(self, x):
        return self.base_forward(x)

    def base_forward(self, x):
        return x * 2

class DerivedModel(BaseModel):
    @torch.jit.script_method
    def forward(self, x):
        result = self.base_forward(x)
        return result + 1
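The call-by-name alternative can also be sketched in plain Python. The BaseModel and DerivedModel classes below are simplified stand-ins (not the PyTorch API), and whether a given TorchScript version compiles an unbound Parent.method(self, x) call should be checked against its documentation:

```python
class BaseModel:
    def forward(self, x):
        return x * 2

class DerivedModel(BaseModel):
    def forward(self, x):
        # Name the parent class explicitly instead of using super():
        # the target method is fixed at the call site, so no MRO lookup
        # is needed at runtime.
        result = BaseModel.forward(self, x)
        return result + 1

print(DerivedModel().forward(3))  # 7
```

Naming the parent class trades flexibility for predictability: the call no longer adapts if the inheritance hierarchy changes, but its target is unambiguous at compile time.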

Documentation