diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 8d073c8a20f..aecf3789e22 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6635,6 +6635,19 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase): mem_after = torch.cuda.memory_allocated() self.assertEqual(mem_before, mem_after) + def test_udf_class_source(self): + class Foo: + pass + + def fn(x): + foo = Foo() + bar = type(foo)() # noqa: F841 + return torch.cos(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @requires_cuda def test_sdpa_dynamic_shapes(self, device): def f(x, s0, s1, s2): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 564ede5ba88..8d4b16544da 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1942,6 +1942,12 @@ class BuiltinVariable(VariableTracker): ) from None source = obj.source and TypeSource(obj.source) + if ( + source is None + and isinstance(obj, variables.UserDefinedObjectVariable) + and obj.cls_source + ): + source = obj.cls_source if py_type is torch.Tensor: # In some cases torch isn't available in globals name = tx.output.install_global_by_id("", torch) @@ -2004,6 +2010,8 @@ class BuiltinVariable(VariableTracker): return tensor_variable.call_id(tx) elif istype(args[0], variables.UserFunctionVariable): return variables.ConstantVariable.create(id(args[0].fn)) + elif istype(args[0], variables.SkipFunctionVariable): + return variables.ConstantVariable.create(id(args[0].value)) else: unimplemented(f"call_id with args {args}") diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 0bda54fdbb4..d2797d13d58 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -734,6 +734,13 @@ class SkipFunctionVariable(VariableTracker): ) -> "VariableTracker": if inspect.getattr_static(self.value, 
"_torchdynamo_disable", False): unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}") + elif isinstance(self.value, types.WrapperDescriptorType): + msg = ( + f"Graph break due to unsupported wrapper descriptor {self.value}. " + "Please file an issue on GitHub " + "so the PyTorch team can add support for it." + ) + torch._dynamo.utils.warn_once(msg) else: try: path = inspect.getfile(self.value)