move and fix logic to update unbacked bindings (#146115)

Summary:
Previously we were touching up unbacked bindings between Dynamo and AOTAutograd in strict export, but the logic had a bug: if an unbacked symint gets substituted by a backed symint, we would put the backed symint in the unbacked bindings (the check `is_symbol` was not enough here).

This PR fixes this logic, and moreover, moves it into the serializer instead, because we don't need this adjustment outside serde.

Test Plan: added test

Differential Revision: D68880766

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146115
Approved by: https://github.com/pianpwk
This commit is contained in:
Avik Chaudhuri 2025-02-07 22:41:18 +00:00 committed by PyTorch MergeBot
parent 45d35f5f5a
commit 103c8b44bc
3 changed files with 51 additions and 20 deletions

View file

@ -980,6 +980,44 @@ graph():
ep_output = ep.module()(seq_embeddings, mask, exp)
self.assertTrue(torch.allclose(output, ep_output))
def test_mask_torch_check(self):
    """Export a module whose boolean-mask indexing yields an unbacked symint,
    where torch._check equates that size with a backed input dimension; then
    verify the exported program matches eager execution on two input sizes."""
    class TestModule(torch.nn.Module):
        def forward(self, seq_embeddings, mask, exp):
            output = seq_embeddings[mask]
            # output.shape has unbacked symint, assert side knowledge of
            # output.shape as exp.shape to force it to have backed symint
            torch._check(output.size(0) == exp.size(0))
            final_output = output * 2
            return final_output

    m = TestModule()
    # mask is all-True, so seq_embeddings[mask] flattens to 5*5 = 25 elements,
    # matching exp's length as required by the torch._check above.
    seq_embeddings = torch.randn(5, 5)
    mask = torch.ones(5, 5, dtype=torch.bool)
    exp = torch.randn(25)
    output = m(seq_embeddings, mask, exp)

    # Mark the first dim of seq_embeddings/mask and the sole dim of exp as
    # dynamic so export hits the unbacked->backed substitution path this
    # test is about.
    batch = torch.export.Dim("batch")
    exp_size = torch.export.Dim("exp_size", max=100)
    ep = export(
        m,
        (seq_embeddings, mask, exp),
        dynamic_shapes={
            "seq_embeddings": (batch, None),
            "mask": (batch, None),
            "exp": (exp_size,),
        },
    )
    ep_output = ep.module()(seq_embeddings, mask, exp)
    self.assertTrue(torch.allclose(output, ep_output))

    # Re-run with a different batch size (6*5 = 30 masked elements) to
    # confirm the exported program generalizes beyond the example inputs.
    seq_embeddings = torch.randn(6, 5)
    mask = torch.ones(6, 5, dtype=torch.bool)
    exp = torch.randn(30)
    output = m(seq_embeddings, mask, exp)
    ep_output = ep.module()(seq_embeddings, mask, exp)
    self.assertTrue(torch.allclose(output, ep_output))
def test_setgrad_lifted_tensor(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -9691,7 +9729,6 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
"torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
).run(ep.graph_module.code)
@testing.expectedFailureSerDerNonStrict
def test_replace_unbacked_with_very_large_upperbound(self):
# beyond 2^53 where python floats lose precision
VERY_LARGE_INT = 1000000007999999992

View file

@ -635,9 +635,19 @@ class GraphModuleSerializer(metaclass=Final):
if unbacked_bindings := node.meta.get("unbacked_bindings"):
# serialize the symbol names of unbacked bindings;
# reconstruct the key paths to those symbols when deserializing
ret["unbacked_bindings"] = ",".join(
u.name for u in unbacked_bindings.keys()
)
val = node.meta["val"]
new_unbacked_bindings = {}
for key in unbacked_bindings.values():
expr = pytree.key_get(val, key).node.expr
if expr.is_symbol and (
expr.name.startswith(prefix_str[SymT.UNBACKED_FLOAT])
or expr.name.startswith(prefix_str[SymT.UNBACKED_INT])
):
new_unbacked_bindings[expr] = key
if new_unbacked_bindings:
ret["unbacked_bindings"] = ",".join(
u.name for u in new_unbacked_bindings.keys()
)
if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace

View file

@ -1399,22 +1399,6 @@ def _strict_export_lower_to_aten_ir(
export_graph_signature = aten_export_artifact.sig
constants = aten_export_artifact.constants
# update unbacked bindings that might have gone out of sync
# between Dynamo and AOTAutograd
for node in gm.graph.nodes:
if "unbacked_bindings" in node.meta:
old_unbacked_bindings = node.meta["unbacked_bindings"]
val = node.meta["val"]
new_unbacked_bindings = {}
for key in old_unbacked_bindings.values():
expr = pytree.key_get(val, key).node.expr
if expr.is_symbol:
new_unbacked_bindings[expr] = key
if new_unbacked_bindings:
node.meta["unbacked_bindings"] = new_unbacked_bindings
else:
del node.meta["unbacked_bindings"]
_populate_param_buffer_metadata_to_new_gm(
params_buffers_to_node_meta, gm, export_graph_signature
)