mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
retracing in strict doesn't like dataclass registration (#144487)
Retracing in strict doesn't seem to like dataclass registration. Just refactoring some tests to make this explicit (whereas other export testing variants work fine). Differential Revision: [D67985149](https://our.internmc.facebook.com/intern/diff/D67985149/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144487 Approved by: https://github.com/angelayi
This commit is contained in:
parent
b2fde28283
commit
a222029f4e
1 changed files with 59 additions and 62 deletions
|
|
@ -169,12 +169,24 @@ def foo_unbacked(x):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Inp:
|
||||
class Inp1:
|
||||
x: Tensor
|
||||
y: List[Tensor]
|
||||
z: Dict[str, Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Inp2:
|
||||
a: Tensor
|
||||
b: Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Inp3:
|
||||
f: torch.Tensor
|
||||
p: torch.Tensor
|
||||
|
||||
|
||||
NON_STRICT_SUFFIX = "_non_strict"
|
||||
RETRACEABILITY_STRICT_SUFFIX = "_retraceability"
|
||||
RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict"
|
||||
|
|
@ -3359,22 +3371,22 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
# retracing doesn't seem to like dataclass registration,
|
||||
# raising a dynamo error in fx_pytree.tree_flatten_spec
|
||||
@testing.expectedFailureRetraceability
|
||||
@testing.expectedFailureRetraceability # T186979579
|
||||
def test_dynamic_shapes_builder_pytree(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp,
|
||||
serialized_type_name="test_dynamic_shapes_builder_pytree.Inp",
|
||||
Inp1,
|
||||
serialized_type_name="test_dynamic_shapes_builder_pytree.Inp1",
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, inp: Inp):
|
||||
def forward(self, inp: Inp1):
|
||||
return inp.x + inp.y[0] + inp.z["k"]
|
||||
|
||||
m = M()
|
||||
x = torch.randn(4)
|
||||
y = [torch.randn(4)]
|
||||
z = {"k": torch.randn(4)}
|
||||
args = (Inp(x, y, z),)
|
||||
args = (Inp1(x, y, z),)
|
||||
|
||||
shapes_collection = torch.export.ShapesCollection()
|
||||
dim = torch.export.Dim("dim", max=10)
|
||||
|
|
@ -4369,8 +4381,36 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
):
|
||||
self.assertTrue("source_fn_stack" in node.meta)
|
||||
|
||||
@testing.expectedFailureRetraceability # T186979579
|
||||
def test_dynamic_shapes_dataclass(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp2,
|
||||
serialized_type_name="test_export_api_with_dynamic_shapes.Inp2",
|
||||
)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, inputs):
|
||||
return torch.matmul(inputs.a, inputs.b)
|
||||
|
||||
foo = Foo()
|
||||
inputs = (Inp2(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
|
||||
batch = Dim("batch")
|
||||
efoo = export(
|
||||
foo,
|
||||
inputs,
|
||||
dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
|
||||
)
|
||||
self.assertEqual(
|
||||
[
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
|
||||
)
|
||||
|
||||
def test_export_api_with_dynamic_shapes(self):
|
||||
from torch.export import Dim, dims, export
|
||||
from torch.export import Dim, dims
|
||||
|
||||
# pass dynamic shapes of inputs [args]
|
||||
class Foo(torch.nn.Module):
|
||||
|
|
@ -4513,43 +4553,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
)
|
||||
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
|
||||
|
||||
# pass dynamic shapes of inputs [dataclass]
|
||||
|
||||
# TODO(avik): This part of the test should have failed both serde and retracing
|
||||
# but these failures are hidden because of the local import of `export` in this test.
|
||||
# The serde failure is benign, and easily avoided by moving the dataclass definition
|
||||
# to the top-level. OTOH the retracing failure needs further investigation.
|
||||
@dataclass
|
||||
class DataClass:
|
||||
a: Tensor
|
||||
b: Tensor
|
||||
|
||||
register_dataclass_as_pytree_node(
|
||||
DataClass,
|
||||
serialized_type_name="test_export_api_with_dynamic_shapes.DataClass",
|
||||
)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, inputs):
|
||||
return torch.matmul(inputs.a, inputs.b)
|
||||
|
||||
foo = Foo()
|
||||
inputs = (DataClass(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
|
||||
batch = Dim("batch")
|
||||
efoo = export(
|
||||
foo,
|
||||
inputs,
|
||||
dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
|
||||
)
|
||||
self.assertEqual(
|
||||
[
|
||||
str(node.meta["val"].shape)
|
||||
for node in efoo.graph_module.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
],
|
||||
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
|
||||
)
|
||||
|
||||
# pass dynamic shapes of inputs [pytree-registered classes]
|
||||
if HAS_TORCHREC:
|
||||
# skipping tests if torchrec not available
|
||||
|
|
@ -4890,7 +4893,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
self.assertTrue(spec, LeafSpec())
|
||||
self.assertTrue(len(flat) == 1)
|
||||
|
||||
register_dataclass_as_pytree_node(
|
||||
torch.export.register_dataclass(
|
||||
MyDataClass,
|
||||
serialized_type_name="test_pytree_register_data_class.MyDataClass",
|
||||
)
|
||||
|
|
@ -4961,10 +4964,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
dt = Outer(xy, ab)
|
||||
inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}
|
||||
|
||||
register_dataclass_as_pytree_node(
|
||||
torch.export.register_dataclass(
|
||||
Inner, serialized_type_name="test_pytree_register_nested_data_class.Inner"
|
||||
)
|
||||
register_dataclass_as_pytree_node(
|
||||
torch.export.register_dataclass(
|
||||
Outer, serialized_type_name="test_pytree_register_nested_data_class.Outer"
|
||||
)
|
||||
|
||||
|
|
@ -6172,24 +6175,20 @@ def forward(self, b_a_buffer, x):
|
|||
ep = export(m, ())
|
||||
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
|
||||
|
||||
@testing.expectedFailureRetraceability # T186979579
|
||||
def test_preserve_shape_dynamism_for_unused_inputs(self):
|
||||
@dataclass
|
||||
class Input:
|
||||
f: torch.Tensor
|
||||
p: torch.Tensor
|
||||
|
||||
torch._export.utils.register_dataclass_as_pytree_node(
|
||||
Input,
|
||||
serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input",
|
||||
torch.export.register_dataclass(
|
||||
Inp3,
|
||||
serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Inp3",
|
||||
)
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Input):
|
||||
def forward(self, x: Inp3):
|
||||
return x.f + 1
|
||||
|
||||
mod = Module()
|
||||
example_inputs = (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)
|
||||
ep_static = torch.export.export(mod, example_inputs)
|
||||
example_inputs = (Inp3(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)
|
||||
ep_static = export(mod, example_inputs)
|
||||
for node in ep_static.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
for s in node.meta["val"].shape:
|
||||
|
|
@ -6197,9 +6196,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p")
|
||||
dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]}
|
||||
ep_dynamic = torch.export.export(
|
||||
mod, example_inputs, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
ep_dynamic = export(mod, example_inputs, dynamic_shapes=dynamic_shapes)
|
||||
for node in ep_dynamic.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
for i, s in enumerate(node.meta["val"].shape):
|
||||
|
|
@ -10944,7 +10941,7 @@ def forward(self, x):
|
|||
a: Tensor
|
||||
b: Tensor
|
||||
|
||||
register_dataclass_as_pytree_node(
|
||||
torch.export.register_dataclass(
|
||||
Input,
|
||||
serialized_type_name="test_dynamic_shapes_serdes_various.Input",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue