mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] misc fixes for inspect (#146283)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146283 Approved by: https://github.com/jansel ghstack dependencies: #146075
This commit is contained in:
parent
6ac8bc0cd2
commit
fa48757180
3 changed files with 28 additions and 0 deletions
|
|
@ -6635,6 +6635,19 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
|||
mem_after = torch.cuda.memory_allocated()
|
||||
self.assertEqual(mem_before, mem_after)
|
||||
|
||||
def test_udf_class_source(self):
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
def fn(x):
|
||||
foo = Foo()
|
||||
bar = type(foo)() # noqa: F841
|
||||
return torch.cos(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
@requires_cuda
|
||||
def test_sdpa_dynamic_shapes(self, device):
|
||||
def f(x, s0, s1, s2):
|
||||
|
|
|
|||
|
|
@ -1942,6 +1942,12 @@ class BuiltinVariable(VariableTracker):
|
|||
) from None
|
||||
|
||||
source = obj.source and TypeSource(obj.source)
|
||||
if (
|
||||
source is None
|
||||
and isinstance(obj, variables.UserDefinedObjectVariable)
|
||||
and obj.cls_source
|
||||
):
|
||||
source = obj.cls_source
|
||||
if py_type is torch.Tensor:
|
||||
# In some cases torch isn't available in globals
|
||||
name = tx.output.install_global_by_id("", torch)
|
||||
|
|
@ -2004,6 +2010,8 @@ class BuiltinVariable(VariableTracker):
|
|||
return tensor_variable.call_id(tx)
|
||||
elif istype(args[0], variables.UserFunctionVariable):
|
||||
return variables.ConstantVariable.create(id(args[0].fn))
|
||||
elif istype(args[0], variables.SkipFunctionVariable):
|
||||
return variables.ConstantVariable.create(id(args[0].value))
|
||||
else:
|
||||
unimplemented(f"call_id with args {args}")
|
||||
|
||||
|
|
|
|||
|
|
@ -734,6 +734,13 @@ class SkipFunctionVariable(VariableTracker):
|
|||
) -> "VariableTracker":
|
||||
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
|
||||
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
|
||||
elif isinstance(self.value, types.WrapperDescriptorType):
|
||||
msg = (
|
||||
f"Graph break due to unsupported wrapper descriptor {self.value}. "
|
||||
f"Please file an issue on GitHub "
|
||||
f"so the PyTorch team can add support for it. "
|
||||
)
|
||||
torch._dynamo.utils.warn_once(msg)
|
||||
else:
|
||||
try:
|
||||
path = inspect.getfile(self.value)
|
||||
|
|
|
|||
Loading…
Reference in a new issue