[dynamo] Check that nn module parameters are not overwritten before taking the tracing shortcut (#137824)

Fixes https://github.com/pytorch/pytorch/issues/136257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137824
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain 2024-10-11 16:36:41 -07:00 committed by PyTorch MergeBot
parent 09e2a0d7bc
commit 3050f2e5dd
2 changed files with 62 additions and 2 deletions

View file

@ -6039,6 +6039,61 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
opt_fn = torch.compile(fn, backend="eager")
self.assertEqual(fn(x), opt_fn(x))
# https://github.com/pytorch/pytorch/issues/136257
def test_overwriting_params(self):
    """Compile a module whose ``_parameters`` dicts were swapped for a subclass.

    Regression test for https://github.com/pytorch/pytorch/issues/136257.
    Every submodule's ``_parameters`` container is replaced with a
    ``collections.OrderedDict`` subclass that overrides ``__getitem__``; the
    model must still compile with ``fullgraph=True`` and produce the same
    output as eager execution.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(2, 2)
            self.fc2 = torch.nn.Linear(2, 2)

        def forward(self, x):
            x = self.fc1(x)
            x = self.fc2(x)
            return x

    class ZeROOrderedDict(collections.OrderedDict):
        def __init__(self, parent_module=None, *args, **kwargs):
            """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.

            Args:
                parent_module: the module whose ``_parameters`` container
                    this dict replaces
            """
            super().__init__(*args, **kwargs)
            self._parent_module = parent_module

        def __getitem__(self, key):
            param = super().__getitem__(key)
            # Params can be registered as None (e.g., bias)
            if param is None:
                return param
            # do something here
            return param

    def inject_parameters(module, cls):
        """Replace ``_parameters`` on ``module`` and all its submodules with a ``cls`` instance."""
        # Use a distinct loop name so the ``module`` argument is not shadowed.
        for submodule in module.modules():
            if cls == ZeROOrderedDict:
                new_param = cls(parent_module=submodule)
            else:
                new_param = cls()
            for key, param in submodule._parameters.items():
                new_param[key] = param
            submodule._parameters = new_param

    model = M()
    inject_parameters(model, ZeROOrderedDict)

    x = torch.ones(2)
    # Eager reference, taken after the injection so both runs see the same
    # overwritten parameter containers.
    with torch.no_grad():
        ref = model(x)

    opt_model = torch.compile(model, backend="eager", fullgraph=True)
    with torch.no_grad():
        res = opt_model(x)

    # Compiling without error is not enough; the result must match eager.
    self.assertEqual(ref, res)
# NOTE(review): presumably expands any parametrized test templates declared on
# ReproTests into concrete test methods — confirm against the helper's definition.
instantiate_parametrized_tests(ReproTests)

View file

@ -1029,8 +1029,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if isinstance(getattr_fn, types.FunctionType):
# Dynamo is going to trace the __getattr__ function with
# args=name. Set the source accordingly.
if getattr_fn is unpatched_nn_module_getattr and isinstance(
self, variables.UnspecializedNNModuleVariable
if (
getattr_fn is unpatched_nn_module_getattr
and isinstance(self, variables.UnspecializedNNModuleVariable)
# guard against params/buffers/submodules containers having been overwritten
and istype(self.value._parameters, dict)
and istype(self.value._buffers, dict)
and istype(self.value._modules, dict)
):
# Manually trace out the nn module __getattr__ to avoid large compilation latency.
out = self.manually_trace_nn_module_getattr(tx, name)