From 3050f2e5dd01e9bc16dead107f8c4e2cd0a88e99 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 11 Oct 2024 16:36:41 -0700
Subject: [PATCH] [dynamo] Check nn module parameters are not overwritten
 before taking tracing shortcut (#137824)

Fixes https://github.com/pytorch/pytorch/issues/136257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137824
Approved by: https://github.com/jansel
---
 test/dynamo/test_repros.py              | 55 +++++++++++++++++++++++++
 torch/_dynamo/variables/user_defined.py |  9 +++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 392d2552820..366d53af419 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -6039,6 +6039,61 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         opt_fn = torch.compile(fn, backend="eager")
         self.assertEqual(fn(x), opt_fn(x))
 
+    # https://github.com/pytorch/pytorch/issues/136257
+    def test_overwriting_params(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(2, 2)
+                self.fc2 = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x
+
+        class ZeROOrderedDict(collections.OrderedDict):
+            def __init__(self, parent_module=None, *args, **kwargs):
+                """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
+
+                Args:
+                    parent_module (``collections.OrderedDict``): the collection to replace
+                """
+
+                super().__init__(*args, **kwargs)
+                self._parent_module = parent_module
+
+            def __getitem__(self, key):
+                param = super().__getitem__(key)
+
+                # Params can be registered as None (e.g., bias)
+                if param is None:
+                    return param
+
+                # do something here
+                return param
+
+        def inject_parameters(module, cls):
+            for module in module.modules():  # noqa: B020
+                if cls == ZeROOrderedDict:
+                    new_param = cls(parent_module=module)
+                else:
+                    new_param = cls()
+
+                for key, param in module._parameters.items():
+                    new_param[key] = param
+                module._parameters = new_param
+
+        model = M()
+
+        inject_parameters(model, ZeROOrderedDict)
+
+        model = torch.compile(model, backend="eager", fullgraph=True)
+
+        x = torch.ones(2)
+        with torch.no_grad():
+            y = model(x)
+
 
 instantiate_parametrized_tests(ReproTests)
 
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 609c4872ea8..3f5c9f866a4 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -1029,8 +1029,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         if isinstance(getattr_fn, types.FunctionType):
             # Dynamo is going to trace the __getattr__ function with
             # args=name. Set the source accordingly.
-            if getattr_fn is unpatched_nn_module_getattr and isinstance(
-                self, variables.UnspecializedNNModuleVariable
+            if (
+                getattr_fn is unpatched_nn_module_getattr
+                and isinstance(self, variables.UnspecializedNNModuleVariable)
+                # guard against overwriting of params/buffers/submodules
+                and istype(self.value._parameters, dict)
+                and istype(self.value._buffers, dict)
+                and istype(self.value._modules, dict)
             ):
                 # Manually trace out the nn module __getattr__ to avoid large compilation latency.
                 out = self.manually_trace_nn_module_getattr(tx, name)
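
A standalone sketch of why the new guard fires for the repro above. It assumes
a PyTorch build where a fresh module's _parameters is a plain dict (which the
guard itself presupposes); istype is the exact-type helper from
torch._dynamo.utils that the patch calls, and this ZeROOrderedDict is a
trimmed stand-in mirroring the test class:

    import collections

    import torch
    from torch._dynamo.utils import istype

    class ZeROOrderedDict(collections.OrderedDict):
        # Trimmed stand-in for the DeepSpeed-style dict in the test above.
        def __init__(self, parent_module=None, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._parent_module = parent_module

    lin = torch.nn.Linear(2, 2)

    # Untouched module: the exact type is dict, so the fast path is sound.
    print(istype(lin._parameters, dict))  # True

    # After injection the exact type is a dict subclass. isinstance(...) would
    # still be True, but istype(...) is False, so Dynamo now traces the real
    # __getattr__ (and the subclass's __getitem__) instead of shortcutting.
    lin._parameters = ZeROOrderedDict(lin, lin._parameters)
    print(istype(lin._parameters, dict))  # False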
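
The shortcut being guarded (manually_trace_nn_module_getattr) is only sound
while attribute lookup reduces to plain dict probes, which is roughly what the
stock torch.nn.Module.__getattr__ does. A simplified sketch of that contract
(not Dynamo's emitted code; the real implementation lives in
torch/nn/modules/module.py):

    # A name missing from the instance __dict__ is looked up, in order, in
    # three per-module dicts; a hit is returned by a plain indexing operation.
    def stock_getattr(module, name):
        for store_name in ("_parameters", "_buffers", "_modules"):
            store = module.__dict__.get(store_name)
            if store is not None and name in store:
                return store[name]
        raise AttributeError(
            f"'{type(module).__name__}' object has no attribute '{name}'"
        )

Once _parameters is a subclass with a custom __getitem__ (as DeepSpeed's ZeRO
does to gather partitioned parameters on access), store[name] is no longer a
plain lookup, and replaying it as one would silently skip the override; hence
the patch only takes the shortcut when all three stores are exactly dict.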