mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
move and fix logic to update unbacked bindings (#146115)
Summary: Previously we were touching up unbacked bindings between Dynamo and AOTAutograd in strict export, but the logic had a bug: if an unbacked symint gets substituted by a backed symint, we would put the backed symint in the unbacked bindings (the check `is_symbol` was not enough here). This PR fixes this logic, and moreover, moves it into the serializer instead, because we don't need this adjustment outside serde. Test Plan: added test D68880766 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146115 Approved by: https://github.com/pianpwk
This commit is contained in:
parent
45d35f5f5a
commit
103c8b44bc
3 changed files with 51 additions and 20 deletions
|
|
@ -980,6 +980,44 @@ graph():
|
|||
ep_output = ep.module()(seq_embeddings, mask, exp)
|
||||
self.assertTrue(torch.allclose(output, ep_output))
|
||||
|
||||
def test_mask_torch_check(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, seq_embeddings, mask, exp):
|
||||
output = seq_embeddings[mask]
|
||||
# output.shape has unbacked symint, assert side knowledge of
|
||||
# output.shape as exp.shape to force it to have backed symint
|
||||
torch._check(output.size(0) == exp.size(0))
|
||||
final_output = output * 2
|
||||
return final_output
|
||||
|
||||
m = TestModule()
|
||||
|
||||
seq_embeddings = torch.randn(5, 5)
|
||||
mask = torch.ones(5, 5, dtype=torch.bool)
|
||||
exp = torch.randn(25)
|
||||
output = m(seq_embeddings, mask, exp)
|
||||
|
||||
batch = torch.export.Dim("batch")
|
||||
exp_size = torch.export.Dim("exp_size", max=100)
|
||||
ep = export(
|
||||
m,
|
||||
(seq_embeddings, mask, exp),
|
||||
dynamic_shapes={
|
||||
"seq_embeddings": (batch, None),
|
||||
"mask": (batch, None),
|
||||
"exp": (exp_size,),
|
||||
},
|
||||
)
|
||||
ep_output = ep.module()(seq_embeddings, mask, exp)
|
||||
self.assertTrue(torch.allclose(output, ep_output))
|
||||
|
||||
seq_embeddings = torch.randn(6, 5)
|
||||
mask = torch.ones(6, 5, dtype=torch.bool)
|
||||
exp = torch.randn(30)
|
||||
output = m(seq_embeddings, mask, exp)
|
||||
ep_output = ep.module()(seq_embeddings, mask, exp)
|
||||
self.assertTrue(torch.allclose(output, ep_output))
|
||||
|
||||
def test_setgrad_lifted_tensor(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
@ -9691,7 +9729,6 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||
"torch.ops.profiler._record_function_enter_new.default", 0, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
def test_replace_unbacked_with_very_large_upperbound(self):
|
||||
# beyond 2^53 where python floats lose precision
|
||||
VERY_LARGE_INT = 1000000007999999992
|
||||
|
|
|
|||
|
|
@ -635,9 +635,19 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
if unbacked_bindings := node.meta.get("unbacked_bindings"):
|
||||
# serialize the symbol names of unbacked bindings;
|
||||
# reconstruct the key paths to those symbols when deserializing
|
||||
ret["unbacked_bindings"] = ",".join(
|
||||
u.name for u in unbacked_bindings.keys()
|
||||
)
|
||||
val = node.meta["val"]
|
||||
new_unbacked_bindings = {}
|
||||
for key in unbacked_bindings.values():
|
||||
expr = pytree.key_get(val, key).node.expr
|
||||
if expr.is_symbol and (
|
||||
expr.name.startswith(prefix_str[SymT.UNBACKED_FLOAT])
|
||||
or expr.name.startswith(prefix_str[SymT.UNBACKED_INT])
|
||||
):
|
||||
new_unbacked_bindings[expr] = key
|
||||
if new_unbacked_bindings:
|
||||
ret["unbacked_bindings"] = ",".join(
|
||||
u.name for u in new_unbacked_bindings.keys()
|
||||
)
|
||||
|
||||
if stack_trace := node.meta.get("stack_trace"):
|
||||
ret["stack_trace"] = stack_trace
|
||||
|
|
|
|||
|
|
@ -1399,22 +1399,6 @@ def _strict_export_lower_to_aten_ir(
|
|||
export_graph_signature = aten_export_artifact.sig
|
||||
constants = aten_export_artifact.constants
|
||||
|
||||
# update unbacked bindings that might have gone out of sync
|
||||
# between Dynamo and AOTAutograd
|
||||
for node in gm.graph.nodes:
|
||||
if "unbacked_bindings" in node.meta:
|
||||
old_unbacked_bindings = node.meta["unbacked_bindings"]
|
||||
val = node.meta["val"]
|
||||
new_unbacked_bindings = {}
|
||||
for key in old_unbacked_bindings.values():
|
||||
expr = pytree.key_get(val, key).node.expr
|
||||
if expr.is_symbol:
|
||||
new_unbacked_bindings[expr] = key
|
||||
if new_unbacked_bindings:
|
||||
node.meta["unbacked_bindings"] = new_unbacked_bindings
|
||||
else:
|
||||
del node.meta["unbacked_bindings"]
|
||||
|
||||
_populate_param_buffer_metadata_to_new_gm(
|
||||
params_buffers_to_node_meta, gm, export_graph_signature
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue