[dynamo] misc fixes for inspect (#146283)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146283
Approved by: https://github.com/jansel
ghstack dependencies: #146075
This commit is contained in:
Animesh Jain 2025-02-02 16:24:19 -08:00 committed by PyTorch MergeBot
parent 6ac8bc0cd2
commit fa48757180
3 changed files with 28 additions and 0 deletions

View file

@ -6635,6 +6635,19 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
mem_after = torch.cuda.memory_allocated()
self.assertEqual(mem_before, mem_after)
def test_udf_class_source(self):
class Foo:
pass
def fn(x):
foo = Foo()
bar = type(foo)() # noqa: F841
return torch.cos(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
@requires_cuda
def test_sdpa_dynamic_shapes(self, device):
def f(x, s0, s1, s2):

View file

@ -1942,6 +1942,12 @@ class BuiltinVariable(VariableTracker):
) from None
source = obj.source and TypeSource(obj.source)
if (
source is None
and isinstance(obj, variables.UserDefinedObjectVariable)
and obj.cls_source
):
source = obj.cls_source
if py_type is torch.Tensor:
# In some cases torch isn't available in globals
name = tx.output.install_global_by_id("", torch)
@ -2004,6 +2010,8 @@ class BuiltinVariable(VariableTracker):
return tensor_variable.call_id(tx)
elif istype(args[0], variables.UserFunctionVariable):
return variables.ConstantVariable.create(id(args[0].fn))
elif istype(args[0], variables.SkipFunctionVariable):
return variables.ConstantVariable.create(id(args[0].value))
else:
unimplemented(f"call_id with args {args}")

View file

@ -734,6 +734,13 @@ class SkipFunctionVariable(VariableTracker):
) -> "VariableTracker":
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
elif isinstance(self.value, types.WrapperDescriptorType):
msg = (
f"Graph break due to unsupported wrapper descriptor {self.value}. "
f"Please file an issue on GitHub "
f"so the PyTorch team can add support for it. "
)
torch._dynamo.utils.warn_once(msg)
else:
try:
path = inspect.getfile(self.value)