From 103c8b44bcb6fbf30b5411c5af19d312427525e7 Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Fri, 7 Feb 2025 22:41:18 +0000 Subject: [PATCH] move and fix logic to update unbacked bindings (#146115) Summary: Previously we were touching up unbacked bindings between Dynamo and AOTAutograd in strict export, but the logic had a bug: if an unbacked symint gets substituted by a backed symint, we would put the backed symint in the unbacked bindings (the check `is_symbol` was not enough here). This PR fixes this logic, and moreover, moves it into the serializer instead, because we don't need this adjustment outside serde. Test Plan: added test D68880766 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146115 Approved by: https://github.com/pianpwk --- test/export/test_export.py | 39 +++++++++++++++++++++++++++++++- torch/_export/serde/serialize.py | 16 ++++++++++--- torch/export/_trace.py | 16 ------------- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index d724d13b0c6..35a459b265b 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -980,6 +980,44 @@ graph(): ep_output = ep.module()(seq_embeddings, mask, exp) self.assertTrue(torch.allclose(output, ep_output)) + def test_mask_torch_check(self): + class TestModule(torch.nn.Module): + def forward(self, seq_embeddings, mask, exp): + output = seq_embeddings[mask] + # output.shape has unbacked symint, assert side knowledge of + # output.shape as exp.shape to force it to have backed symint + torch._check(output.size(0) == exp.size(0)) + final_output = output * 2 + return final_output + + m = TestModule() + + seq_embeddings = torch.randn(5, 5) + mask = torch.ones(5, 5, dtype=torch.bool) + exp = torch.randn(25) + output = m(seq_embeddings, mask, exp) + + batch = torch.export.Dim("batch") + exp_size = torch.export.Dim("exp_size", max=100) + ep = export( + m, + (seq_embeddings, mask, exp), + dynamic_shapes={ + "seq_embeddings": (batch, None), + "mask": (batch, None), + "exp": (exp_size,), + }, + ) + ep_output = ep.module()(seq_embeddings, mask, exp) + self.assertTrue(torch.allclose(output, ep_output)) + + seq_embeddings = torch.randn(6, 5) + mask = torch.ones(6, 5, dtype=torch.bool) + exp = torch.randn(30) + output = m(seq_embeddings, mask, exp) + ep_output = ep.module()(seq_embeddings, mask, exp) + self.assertTrue(torch.allclose(output, ep_output)) + def test_setgrad_lifted_tensor(self): class M(torch.nn.Module): def forward(self, x, y): @@ -9691,7 +9729,6 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x): "torch.ops.profiler._record_function_enter_new.default", 0, exactly=True ).run(ep.graph_module.code) - @testing.expectedFailureSerDerNonStrict def test_replace_unbacked_with_very_large_upperbound(self): # beyond 2^53 where python floats lose precision VERY_LARGE_INT = 1000000007999999992 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index a6caf384bb9..eb0fe292361 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -635,9 +635,19 @@ class GraphModuleSerializer(metaclass=Final): if unbacked_bindings := node.meta.get("unbacked_bindings"): # serialize the symbol names of unbacked bindings; # reconstruct the key paths to those symbols when deserializing - ret["unbacked_bindings"] = ",".join( - u.name for u in unbacked_bindings.keys() - ) + val = node.meta["val"] + new_unbacked_bindings = {} + for key in unbacked_bindings.values(): + expr = pytree.key_get(val, key).node.expr + if expr.is_symbol and ( + expr.name.startswith(prefix_str[SymT.UNBACKED_FLOAT]) + or expr.name.startswith(prefix_str[SymT.UNBACKED_INT]) + ): + new_unbacked_bindings[expr] = key + if new_unbacked_bindings: + ret["unbacked_bindings"] = ",".join( + u.name for u in new_unbacked_bindings.keys() + ) if stack_trace := node.meta.get("stack_trace"): ret["stack_trace"] = stack_trace diff --git a/torch/export/_trace.py b/torch/export/_trace.py index ec6a3a6c31d..22ec6548657 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1399,22 +1399,6 @@ def _strict_export_lower_to_aten_ir( export_graph_signature = aten_export_artifact.sig constants = aten_export_artifact.constants - # update unbacked bindings that might have gone out of sync - # between Dynamo and AOTAutograd - for node in gm.graph.nodes: - if "unbacked_bindings" in node.meta: - old_unbacked_bindings = node.meta["unbacked_bindings"] - val = node.meta["val"] - new_unbacked_bindings = {} - for key in old_unbacked_bindings.values(): - expr = pytree.key_get(val, key).node.expr - if expr.is_symbol: - new_unbacked_bindings[expr] = key - if new_unbacked_bindings: - node.meta["unbacked_bindings"] = new_unbacked_bindings - else: - del node.meta["unbacked_bindings"] - _populate_param_buffer_metadata_to_new_gm( params_buffers_to_node_meta, gm, export_graph_signature )