mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo][user-defined] Improve getattr_static for user_defined objects (#133742)
Fixes https://github.com/pytorch/pytorch/issues/133607 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133742 Approved by: https://github.com/Skylion007, https://github.com/jansel
This commit is contained in:
parent
a36739f36a
commit
1ae5d5bb62
19 changed files with 44 additions and 16 deletions
|
|
@ -5018,6 +5018,28 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
self.assertEqual(v.data.shape, (10, 20))
|
||||
self.assertEqual(type(v), Matrix)
|
||||
|
||||
def test_classmethod_with_slots(self):
|
||||
class Mock:
|
||||
__slots__ = ("_a",)
|
||||
|
||||
def __init__(self):
|
||||
self._a = 2
|
||||
|
||||
@classmethod
|
||||
def _m(cls):
|
||||
return 3
|
||||
|
||||
def run(self, x):
|
||||
return torch.sin(x) * self._a * self._m()
|
||||
|
||||
def fn(x):
|
||||
mock = Mock()
|
||||
return mock.run(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_nn_parametrize(self):
|
||||
class Module(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import inspect
|
|||
import itertools
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, Generic, List, TYPE_CHECKING
|
||||
|
|
@ -890,23 +889,30 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
return get_custom_getattr(self.value)
|
||||
|
||||
def _getattr_static(self, name):
|
||||
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
|
||||
import _collections
|
||||
|
||||
# In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
|
||||
# has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
|
||||
if (
|
||||
isinstance(self.value, PyTreeSpec)
|
||||
or "__slots__" in self.value.__class__.__dict__
|
||||
or type(self.value) == threading.local
|
||||
subobj is NO_SUCH_SUBOBJ # e.g., threading.local
|
||||
or isinstance(
|
||||
subobj, _collections._tuplegetter
|
||||
) # namedtuple fields are represented by _tuplegetter
|
||||
or (
|
||||
inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
|
||||
) # handle memberdecriptor and slots
|
||||
or (
|
||||
isinstance(subobj, property)
|
||||
and isinstance(
|
||||
subobj.fget, types.BuiltinFunctionType
|
||||
) # property with C-defined fget
|
||||
)
|
||||
):
|
||||
try:
|
||||
cls_var = inspect.getattr_static(
|
||||
self.value.__class__, name, NO_SUCH_SUBOBJ
|
||||
)
|
||||
if cls_var is not NO_SUCH_SUBOBJ and name not in self.value.__dict__:
|
||||
# maybe user-defined @property that we need to inline
|
||||
return cls_var
|
||||
except AttributeError:
|
||||
pass # __slots__
|
||||
subobj = getattr(self.value, name)
|
||||
else:
|
||||
subobj = inspect.getattr_static(self.value, name)
|
||||
# Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
|
||||
# want to call getattr because it can be user-overridden.
|
||||
subobj = self.value.__getattribute__(name)
|
||||
|
||||
return subobj
|
||||
|
||||
def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
|
||||
|
|
|
|||
Loading…
Reference in a new issue