diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 4470bcf2c09..4de501d9350 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6220,6 +6220,38 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): with torch.device("cpu"): self.assertEqual(res, split(x)) + def test_method_overriding(self): + class DilateConv(torch.nn.Module): + def __init__( + self, + dilate_func=None, + ): + super().__init__() + self.dilate_func = dilate_func + + def forward(self, x): + return self.dilate_func() * torch.sin(x) + + class MainModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = DilateConv(self.dilate_func) + self.a = 4 + + def dilate_func(self): + return self.a + + def forward(self, x): + return self.mod(x) + + mod = MainModule() + + opt_mod = torch.compile(mod, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = mod(x) + res = opt_mod(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5ee9bede0d6..7387a577db1 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2978,3 +2978,26 @@ class SourcelessBuilder: SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers() + + +class SourcelessUserDefinedObjectBuilder: + """ + SourceLessBuilder does not return a UserDefinedObjectVariable, but in some + cases it might be ok to return UserDefinedObjects. In such case, use this + builder. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + if issubclass(value_type, MutableMapping): + return MutableMappingVariable(value, mutation_type=ValueMutationNew()) + elif isinstance(value, torch.nn.Module): + return UnspecializedNNModuleVariable( + value, mutation_type=ValueMutationNew() + ) + else: + return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew()) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index bda8fc407a8..c02e1349a5e 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1158,7 +1158,20 @@ class UserDefinedObjectVariable(UserDefinedVariable): if isinstance(subobj, types.MethodType): if dynamic_subobj.__self__ is not self.value: - unimplemented("__self__ mismatch for bound method") + if not isinstance(dynamic_subobj.__func__, types.FunctionType): + unimplemented( + f"Found a method whose __func__ is not of FunctionType - {dynamic_subobj}" + ) + + from .builder import SourcelessUserDefinedObjectBuilder + + # This means that we are calling a method of some other object here. + object_vt = SourcelessUserDefinedObjectBuilder.create( + tx, dynamic_subobj.__self__ + ) + return variables.UserMethodVariable( + dynamic_subobj.__func__, object_vt + ) func = subobj.__func__ else: assert isinstance(subobj, types.FunctionType)