[dynamo][user-defined] Improve getattr_static for user_defined objects (#133742)

Fixes https://github.com/pytorch/pytorch/issues/133607

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133742
Approved by: https://github.com/Skylion007, https://github.com/jansel
This commit is contained in:
Animesh Jain 2024-08-20 10:37:52 -07:00 committed by PyTorch MergeBot
parent a36739f36a
commit 1ae5d5bb62
19 changed files with 44 additions and 16 deletions

View file

@ -5018,6 +5018,28 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
self.assertEqual(v.data.shape, (10, 20))
self.assertEqual(type(v), Matrix)
def test_classmethod_with_slots(self):
class Mock:
__slots__ = ("_a",)
def __init__(self):
self._a = 2
@classmethod
def _m(cls):
return 3
def run(self, x):
return torch.sin(x) * self._a * self._m()
def fn(x):
mock = Mock()
return mock.run(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
def test_nn_parametrize(self):
class Module(nn.Module):
def __init__(self) -> None:

View file

@ -8,7 +8,6 @@ import inspect
import itertools
import random
import sys
import threading
import types
import warnings
from typing import Dict, Generic, List, TYPE_CHECKING
@ -890,23 +889,30 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return get_custom_getattr(self.value)
def _getattr_static(self, name):
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
import _collections
# In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
# has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
if (
isinstance(self.value, PyTreeSpec)
or "__slots__" in self.value.__class__.__dict__
or type(self.value) == threading.local
subobj is NO_SUCH_SUBOBJ # e.g., threading.local
or isinstance(
subobj, _collections._tuplegetter
) # namedtuple fields are represented by _tuplegetter
or (
inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
) # handle memberdecriptor and slots
or (
isinstance(subobj, property)
and isinstance(
subobj.fget, types.BuiltinFunctionType
) # property with C-defined fget
)
):
try:
cls_var = inspect.getattr_static(
self.value.__class__, name, NO_SUCH_SUBOBJ
)
if cls_var is not NO_SUCH_SUBOBJ and name not in self.value.__dict__:
# maybe user-defined @property that we need to inline
return cls_var
except AttributeError:
pass # __slots__
subobj = getattr(self.value, name)
else:
subobj = inspect.getattr_static(self.value, name)
# Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
# want to call getattr because it can be user-overridden.
subobj = self.value.__getattribute__(name)
return subobj
def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):