This rule raises an issue when super() is called in methods decorated with TorchScript decorators such as
@torch.jit.script_method, or in methods of classes likely to be compiled with TorchScript.
TorchScript has limited support for Python’s super() mechanism, which can lead to compilation errors when converting PyTorch models
for deployment.
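Detection of this pattern can be approximated with Python's standard ast module. The sketch below is hypothetical and is not the analyzer's actual implementation. The function name find_super_in_torchscript and the set of decorator spellings it recognizes are assumptions. It matches decorators by their literal dotted name instead of resolving imports or aliases, and it does not attempt the "classes likely to be compiled" heuristic.

```python
import ast

# Decorator spellings treated as TorchScript in this sketch (an assumption:
# a real checker would resolve imports and aliases such as `from torch import jit`).
TORCHSCRIPT_DECORATORS = {"torch.jit.script_method", "torch.jit.script"}


def _dotted_name(node: ast.expr) -> str:
    """Turn an expression like torch.jit.script_method into a dotted string."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return f"{_dotted_name(node.value)}.{node.attr}"
    if isinstance(node, ast.Call):  # decorator used with arguments: @deco(...)
        return _dotted_name(node.func)
    return ""


def find_super_in_torchscript(source: str) -> list[int]:
    """Return line numbers of super() calls inside TorchScript-decorated functions."""
    tree = ast.parse(source)
    hits = set()
    for func in ast.walk(tree):
        if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        decorators = {_dotted_name(d) for d in func.decorator_list}
        if not decorators & TORCHSCRIPT_DECORATORS:
            continue  # undecorated methods are ignored by this sketch
        for node in ast.walk(func):
            # `super().forward(x)` contains a Call whose func is the bare name `super`.
            if (isinstance(node, ast.Call)
                    and isinstance(node.func, ast.Name)
                    and node.func.id == "super"):
                hits.add(node.lineno)
    return sorted(hits)


SAMPLE = '''
import torch
import torch.nn as nn

class MyModel(nn.Module):
    @torch.jit.script_method
    def forward(self, x):
        return super().forward(x)

    def helper(self, x):
        return super().forward(x)
'''

print(find_super_in_torchscript(SAMPLE))  # [8]
```

Only the call in the decorated forward() is reported; the super() call in the undecorated helper() is left alone, because this sketch does not guess which classes will be scripted.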
TorchScript is PyTorch’s way to create serializable and optimizable models from PyTorch code. It allows you to run models independently of Python, which is essential for production deployment, mobile applications, and performance optimization.
However, TorchScript operates with a subset of Python’s features. The super() function relies on Python’s method resolution order
(MRO) and dynamic attribute lookup, which are not fully supported in TorchScript’s static compilation environment.
When the TorchScript compiler encounters a super() call, it may be unable to resolve which method the call refers to, so the model
fails to compile or raises errors or behaves unexpectedly at runtime. This is particularly problematic in forward() methods of neural
network modules, where inheritance is commonly used.
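The dynamic lookup described above can be seen in plain Python, with no PyTorch involved. In this illustrative sketch (the class names Base, Scale, Shift, and Pipeline are made up for the example), the same super() call inside Scale.forward dispatches to a different method depending on the runtime class of self. Resolving it therefore requires the MRO of the final class, not just the class in which the call is written.

```python
class Base:
    def forward(self, x):
        return f"Base({x})"


class Scale(Base):
    def forward(self, x):
        # The target of super() is chosen from type(self).__mro__ at call time.
        return super().forward(f"Scale({x})")


class Shift(Base):
    def forward(self, x):
        return super().forward(f"Shift({x})")


class Pipeline(Scale, Shift):
    pass


# On a plain Scale instance, super() in Scale.forward reaches Base.forward.
print(Scale().forward("x"))     # Base(Scale(x))
# On a Pipeline instance, the same line reaches Shift.forward instead.
print(Pipeline().forward("x"))  # Base(Shift(Scale(x)))
print([c.__name__ for c in Pipeline.__mro__])  # ['Pipeline', 'Scale', 'Shift', 'Base', 'object']
```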
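Using super() calls in TorchScript methods can therefore cause compilation errors when the model is converted with TorchScript, as well as runtime errors or unexpected behavior in the deployed model.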
Replace super() calls with direct method calls or refactor the inheritance structure to avoid super() usage in TorchScript methods.
Noncompliant code example

import torch
import torch.nn as nn

class MyModel(nn.Module):
    @torch.jit.script_method
    def forward(self, x):
        return super().forward(x)  # Noncompliant
Compliant solution

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        # Avoid super() in TorchScript methods
        return self.process(x)

    def process(self, x):
        return x
For more complex inheritance hierarchies, move the parent logic into a separately named method on the parent class and call that method through self instead of using super().
Noncompliant code example

import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def forward(self, x):
        return x * 2

class DerivedModel(BaseModel):
    @torch.jit.script_method
    def forward(self, x):
        result = super().forward(x)  # Noncompliant
        return result + 1
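Compliant solution

import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def forward(self, x):
        return self.base_forward(x)

    def base_forward(self, x):
        # Parent logic under its own name, so subclasses can reach it without super()
        return x * 2

class DerivedModel(BaseModel):
    def forward(self, x):
        result = self.base_forward(x)  # Instead of super().forward(x)
        return result + 1

# torch.jit.script compiles forward() and the inherited base_forward() it calls
scripted = torch.jit.script(DerivedModel())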