[dynamo] Check nn modules parameters are not overwritten before taking tracing shortcut (#137824)
Fixes https://github.com/pytorch/pytorch/issues/136257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137824
Approved by: https://github.com/jansel
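For context: Dynamo's shortcut for nn module attribute lookup (manually_trace_nn_module_getattr) reads _parameters, _buffers, and _modules directly instead of tracing nn.Module.__getattr__ bytecode, which is only sound while those containers are plain dicts. DeepSpeed's ZeRO, for instance, replaces _parameters with an OrderedDict subclass whose __getitem__ performs extra work, and the shortcut would silently skip that work. Below is a condensed, hypothetical sketch of the problematic pattern (the class name is illustrative; the regression test in this commit exercises the full version):

```python
import collections

import torch


class SideEffectParams(collections.OrderedDict):
    # Illustrative stand-in for DeepSpeed's ZeROOrderedDict: item access is
    # expected to run extra logic (e.g., gathering a partitioned parameter).
    def __getitem__(self, key):
        param = super().__getitem__(key)
        # ... the side effect would run here before returning ...
        return param


lin = torch.nn.Linear(2, 2)
lin._parameters = SideEffectParams(lin._parameters)  # overwrite the plain dict
# Accessing lin.weight now routes through SideEffectParams.__getitem__ via
# nn.Module.__getattr__; a shortcut that reads the dict directly would skip it.
_ = lin.weight
```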
This commit is contained in:
parent 09e2a0d7bc
commit 3050f2e5dd
2 changed files with 62 additions and 2 deletions
The first hunk adds a regression test to ReproTests:

@@ -6039,6 +6039,61 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         opt_fn = torch.compile(fn, backend="eager")
         self.assertEqual(fn(x), opt_fn(x))
 
+    # https://github.com/pytorch/pytorch/issues/136257
+    def test_overwriting_params(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(2, 2)
+                self.fc2 = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x
+
+        class ZeROOrderedDict(collections.OrderedDict):
+            def __init__(self, parent_module=None, *args, **kwargs):
+                """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
+
+                Args:
+                    parent_module (``collections.OrderedDict``): the collection to replace
+                """
+                super().__init__(*args, **kwargs)
+                self._parent_module = parent_module
+
+            def __getitem__(self, key):
+                param = super().__getitem__(key)
+
+                # Params can be registered as None (e.g., bias)
+                if param is None:
+                    return param
+
+                # do something here
+                return param
+
+        def inject_parameters(module, cls):
+            for module in module.modules():  # noqa: B020
+                if cls == ZeROOrderedDict:
+                    new_param = cls(parent_module=module)
+                else:
+                    new_param = cls()
+
+                for key, param in module._parameters.items():
+                    new_param[key] = param
+                module._parameters = new_param
+
+        model = M()
+
+        inject_parameters(model, ZeROOrderedDict)
+
+        model = torch.compile(model, backend="eager", fullgraph=True)
+
+        x = torch.ones(2)
+        with torch.no_grad():
+            y = model(x)
+
 
 instantiate_parametrized_tests(ReproTests)
 
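As a sanity check on the guard's premise: an unmodified module keeps plain dict containers, so ordinary models still get the fast path. A quick sketch (assuming a recent PyTorch where nn.Module stores these containers as plain dicts, which the istype(..., dict) checks in this commit rely on):

```python
import collections

import torch

m = torch.nn.Linear(2, 2)
# The exact-type guard passes for an untouched module...
assert type(m._parameters) is dict
assert type(m._buffers) is dict
assert type(m._modules) is dict

# ...but fails once _parameters is swapped for a dict subclass, as the
# regression test does with ZeROOrderedDict:
m._parameters = collections.OrderedDict(m._parameters)
assert type(m._parameters) is not dict  # the shortcut is now skipped
```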
The second hunk tightens the condition guarding the tracing shortcut in UserDefinedObjectVariable:

@@ -1029,8 +1029,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         if isinstance(getattr_fn, types.FunctionType):
             # Dynamo is going to trace the __getattr__ function with
             # args=name. Set the source accordingly.
-            if getattr_fn is unpatched_nn_module_getattr and isinstance(
-                self, variables.UnspecializedNNModuleVariable
+            if (
+                getattr_fn is unpatched_nn_module_getattr
+                and isinstance(self, variables.UnspecializedNNModuleVariable)
+                # prevent against overwriting of params/buffers/submodules
+                and istype(self.value._parameters, dict)
+                and istype(self.value._buffers, dict)
+                and istype(self.value._modules, dict)
             ):
                 # Manually trace out the nn module __getattr__ to avoid large compilation latency.
                 out = self.manually_trace_nn_module_getattr(tx, name)
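The new conditions catch the ZeROOrderedDict case because istype performs an exact type comparison, unlike isinstance, which accepts any dict subclass. A simplified stand-in to illustrate the distinction (assuming torch._dynamo.utils.istype reduces to a type(...) is ... check for a single type, as its single-type form does):

```python
import collections


def istype(obj, allowed_type):
    # Simplified stand-in for torch._dynamo.utils.istype: exact match only.
    return type(obj) is allowed_type


class ZeROOrderedDict(collections.OrderedDict):
    pass


params = ZeROOrderedDict()
assert isinstance(params, dict)  # isinstance accepts subclasses,
assert not istype(params, dict)  # the exact-type check does not, so Dynamo
                                 # falls back to tracing __getattr__ instead
```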