retracing in strict doesn't like dataclass registration (#144487)

Retracing in strict mode doesn't seem to handle dataclass registration: it raises a dynamo error inside fx_pytree.tree_flatten_spec. This change just refactors some tests to make that failure explicit (the other export testing variants work fine).
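Not part of the commit, but for context, a minimal sketch of the pattern that trips retracing. The names `Inp1`/`M` mirror the refactored tests; the registration call and export flow use the standard public APIs:

```python
from dataclasses import dataclass
from typing import Dict, List

import torch
from torch import Tensor


@dataclass
class Inp1:
    x: Tensor
    y: List[Tensor]
    z: Dict[str, Tensor]


# Registration makes the dataclass a pytree node so export can flatten it.
torch.export.register_dataclass(Inp1, serialized_type_name="repro.Inp1")


class M(torch.nn.Module):
    def forward(self, inp: Inp1):
        return inp.x + inp.y[0] + inp.z["k"]


args = (Inp1(torch.randn(4), [torch.randn(4)], {"k": torch.randn(4)}),)
ep = torch.export.export(M(), args, strict=True)  # first export is fine

# Retracing (exporting the unlifted module again) is where strict mode
# trips: a dynamo error inside fx_pytree.tree_flatten_spec.
ep2 = torch.export.export(ep.module(), args, strict=True)
```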

Differential Revision: [D67985149](https://our.internmc.facebook.com/intern/diff/D67985149/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144487
Approved by: https://github.com/angelayi
Author: Avik Chaudhuri
Date: 2025-01-09 16:42:07 -08:00
Committed by: PyTorch MergeBot
Parent: b2fde28283
Commit: a222029f4e

@@ -169,12 +169,24 @@ def foo_unbacked(x):
 @dataclass
-class Inp:
+class Inp1:
     x: Tensor
     y: List[Tensor]
     z: Dict[str, Tensor]
+
+
+@dataclass
+class Inp2:
+    a: Tensor
+    b: Tensor
+
+
+@dataclass
+class Inp3:
+    f: torch.Tensor
+    p: torch.Tensor
 
 
 NON_STRICT_SUFFIX = "_non_strict"
 RETRACEABILITY_STRICT_SUFFIX = "_retraceability"
 RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict"
@@ -3359,22 +3371,22 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
 
     # retracing doesn't seem to like dataclass registration,
     # raising a dynamo error in fx_pytree.tree_flatten_spec
-    @testing.expectedFailureRetraceability
+    @testing.expectedFailureRetraceability  # T186979579
     def test_dynamic_shapes_builder_pytree(self):
         torch.export.register_dataclass(
-            Inp,
-            serialized_type_name="test_dynamic_shapes_builder_pytree.Inp",
+            Inp1,
+            serialized_type_name="test_dynamic_shapes_builder_pytree.Inp1",
         )
 
         class M(torch.nn.Module):
-            def forward(self, inp: Inp):
+            def forward(self, inp: Inp1):
                 return inp.x + inp.y[0] + inp.z["k"]
 
         m = M()
         x = torch.randn(4)
         y = [torch.randn(4)]
         z = {"k": torch.randn(4)}
-        args = (Inp(x, y, z),)
+        args = (Inp1(x, y, z),)
 
         shapes_collection = torch.export.ShapesCollection()
         dim = torch.export.Dim("dim", max=10)
@@ -4369,8 +4381,36 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         ):
             self.assertTrue("source_fn_stack" in node.meta)
 
+    @testing.expectedFailureRetraceability  # T186979579
+    def test_dynamic_shapes_dataclass(self):
+        torch.export.register_dataclass(
+            Inp2,
+            serialized_type_name="test_export_api_with_dynamic_shapes.Inp2",
+        )
+
+        class Foo(torch.nn.Module):
+            def forward(self, inputs):
+                return torch.matmul(inputs.a, inputs.b)
+
+        foo = Foo()
+        inputs = (Inp2(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
+        batch = Dim("batch")
+        efoo = export(
+            foo,
+            inputs,
+            dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
+        )
+        self.assertEqual(
+            [
+                str(node.meta["val"].shape)
+                for node in efoo.graph_module.graph.nodes
+                if node.op == "placeholder"
+            ],
+            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
+        )
+
     def test_export_api_with_dynamic_shapes(self):
-        from torch.export import Dim, dims, export
+        from torch.export import Dim, dims
 
         # pass dynamic shapes of inputs [args]
         class Foo(torch.nn.Module):
@@ -4513,43 +4553,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         )
         self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
 
-        # pass dynamic shapes of inputs [dataclass]
-
-        # TODO(avik): This part of the test should have failed both serde and retracing
-        # but these failures are hidden because of the local import of `export` in this test.
-        # The serde failure is benign, and easily avoided by moving the dataclass definition
-        # to the top-level. OTOH the retracing failure needs further investigation.
-        @dataclass
-        class DataClass:
-            a: Tensor
-            b: Tensor
-
-        register_dataclass_as_pytree_node(
-            DataClass,
-            serialized_type_name="test_export_api_with_dynamic_shapes.DataClass",
-        )
-
-        class Foo(torch.nn.Module):
-            def forward(self, inputs):
-                return torch.matmul(inputs.a, inputs.b)
-
-        foo = Foo()
-        inputs = (DataClass(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
-        batch = Dim("batch")
-        efoo = export(
-            foo,
-            inputs,
-            dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
-        )
-        self.assertEqual(
-            [
-                str(node.meta["val"].shape)
-                for node in efoo.graph_module.graph.nodes
-                if node.op == "placeholder"
-            ],
-            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
-        )
-
         # pass dynamic shapes of inputs [pytree-registered classes]
         if HAS_TORCHREC:
             # skipping tests if torchrec not available
@@ -4890,7 +4893,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         self.assertTrue(spec, LeafSpec())
         self.assertTrue(len(flat) == 1)
 
-        register_dataclass_as_pytree_node(
+        torch.export.register_dataclass(
             MyDataClass,
             serialized_type_name="test_pytree_register_data_class.MyDataClass",
         )
@@ -4961,10 +4964,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         dt = Outer(xy, ab)
         inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}
 
-        register_dataclass_as_pytree_node(
+        torch.export.register_dataclass(
             Inner, serialized_type_name="test_pytree_register_nested_data_class.Inner"
         )
-        register_dataclass_as_pytree_node(
+        torch.export.register_dataclass(
             Outer, serialized_type_name="test_pytree_register_nested_data_class.Outer"
         )
@@ -6172,24 +6175,20 @@ def forward(self, b_a_buffer, x):
         ep = export(m, ())
         self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
 
+    @testing.expectedFailureRetraceability  # T186979579
     def test_preserve_shape_dynamism_for_unused_inputs(self):
-        @dataclass
-        class Input:
-            f: torch.Tensor
-            p: torch.Tensor
-
-        torch._export.utils.register_dataclass_as_pytree_node(
-            Input,
-            serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input",
+        torch.export.register_dataclass(
+            Inp3,
+            serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Inp3",
         )
 
         class Module(torch.nn.Module):
-            def forward(self, x: Input):
+            def forward(self, x: Inp3):
                 return x.f + 1
 
         mod = Module()
-        example_inputs = (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)
-        ep_static = torch.export.export(mod, example_inputs)
+        example_inputs = (Inp3(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)
+        ep_static = export(mod, example_inputs)
 
         for node in ep_static.graph.nodes:
             if node.op == "placeholder":
                 for s in node.meta["val"].shape:
@@ -6197,9 +6196,7 @@ def forward(self, b_a_buffer, x):
         dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p")
         dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]}
-        ep_dynamic = torch.export.export(
-            mod, example_inputs, dynamic_shapes=dynamic_shapes
-        )
+        ep_dynamic = export(mod, example_inputs, dynamic_shapes=dynamic_shapes)
 
         for node in ep_dynamic.graph.nodes:
             if node.op == "placeholder":
                 for i, s in enumerate(node.meta["val"].shape):
@@ -10944,7 +10941,7 @@ def forward(self, x):
             a: Tensor
             b: Tensor
 
-        register_dataclass_as_pytree_node(
+        torch.export.register_dataclass(
             Input,
             serialized_type_name="test_dynamic_shapes_serdes_various.Input",
         )
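
Aside on the API migration in the hunks above: the tests swap the internal torch._export.utils.register_dataclass_as_pytree_node helper for the public torch.export.register_dataclass. A minimal sketch of what registration buys (the `Pair` class and names here are illustrative, not from this PR):

```python
from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree


@dataclass
class Pair:
    a: torch.Tensor
    b: torch.Tensor


# Public replacement for the internal register_dataclass_as_pytree_node helper.
torch.export.register_dataclass(Pair, serialized_type_name="example.Pair")

# Once registered, the dataclass flattens/unflattens like any pytree container,
# which is what lets torch.export trace through structured inputs.
flat, spec = pytree.tree_flatten(Pair(torch.ones(2), torch.zeros(2)))
assert len(flat) == 2
restored = pytree.tree_unflatten(flat, spec)
assert isinstance(restored, Pair)
```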