[export] handle buffer/input mutations for joint-graph (#144806)

Summary: The previous construction of GraphSignature output specs for the joint graph didn't account for buffer and user-input mutations. These mutations appear as graph outputs ahead of the loss output and gradients, so the old specs were misaligned with the actual outputs.
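
A minimal repro sketch (mirroring the new test added below; module and variable names are illustrative):

    import torch
    from torch.export import export
    from torch.export.experimental import _export_forward_backward
    from torch.export.graph_signature import OutputKind

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(4, 4)
            self.register_buffer("buf", torch.randn(4))
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, label):
            x.add_(self.buf)    # user-input mutation
            x = self.l(x)
            self.buf.add_(2.0)  # buffer mutation
            return self.loss(x, label)

    inputs = (torch.randn(4, 4), torch.randint(0, 4, (4,)))
    ep_joint = _export_forward_backward(export(M(), inputs))

    # With this fix, the mutation outputs precede the loss output and gradients:
    # [BUFFER_MUTATION, USER_INPUT_MUTATION, LOSS_OUTPUT, GRADIENT_TO_PARAMETER, GRADIENT_TO_PARAMETER]
    print([spec.kind for spec in ep_joint.graph_signature.output_specs])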

Test Plan: test_experimental
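For an OSS checkout, the new test should be runnable with something like:

    python test/export/test_experimental.py -k test_joint_buffer_input_mutations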

Differential Revision: D68177409

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144806
Approved by: https://github.com/zhxchen17, https://github.com/avikchaudhuri
commit 774f21a370
parent d7f45fc575
Author: Pian Pawakapan
Date:   2025-01-16 00:22:16 +00:00
Committed-by: PyTorch MergeBot

2 changed files with 83 additions and 9 deletions

test/export/test_experimental.py

@@ -10,6 +10,7 @@ from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export, export_for_training
 from torch.export._trace import _convert_ts_to_export_experimental
 from torch.export.experimental import _export_forward_backward
+from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
@@ -264,8 +265,6 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
         ep = _export_forward_backward(ep)
 
     def test_joint_loss_index(self):
-        from torch.export.graph_signature import OutputKind
-
         class Foo(torch.nn.Module):
             def __init__(self, index):
                 super().__init__()
@@ -290,6 +289,48 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
             else:
                 self.assertTrue(spec.kind != OutputKind.LOSS_OUTPUT)
 
+    def test_joint_buffer_input_mutations(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l = torch.nn.Linear(4, 4)
+                self.register_buffer("buf", torch.randn(4))
+                self.loss = torch.nn.CrossEntropyLoss()
+
+            def forward(self, x, label):
+                x.add_(self.buf)
+                x = self.l(x)
+                self.buf.add_(2.0)
+                return self.loss(x, label)
+
+        inputs = (
+            torch.randn(4, 4),
+            torch.randint(0, 4, (4,)),
+        )
+        ep = export(Foo(), inputs)
+        ep_joint = _export_forward_backward(ep)
+        self.assertEqual(len(ep_joint.graph_signature.output_specs), 5)
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[0].kind,
+            OutputKind.BUFFER_MUTATION,
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[0].target,
+            "buf",
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[1].kind,
+            OutputKind.USER_INPUT_MUTATION,
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[1].target,
+            "x",
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[2].kind,
+            OutputKind.LOSS_OUTPUT,
+        )
+
 
 if __name__ == "__main__":
     run_tests()

torch/export/exported_program.py

@@ -641,14 +641,47 @@ def _decompose_and_get_gm_with_new_signature_constants(
         for i, spec in enumerate(ep.graph_signature.input_specs)
     ]
-    output_specs = [
-        OutputSpec(
-            OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
-            update_arg(spec.arg, new_outputs[i]),
-            old_new_placeholder_map.get(spec.target, spec.target),
-        )
-        for i, spec in enumerate(ep.graph_signature.output_specs)
-    ]
+    output_specs = []
+
+    # handle buffer & input mutations; these appear before loss output & gradients
+    # (1) ep.graph_signature.input_specs tells us types of inputs
+    # (2) graph_signature.user_inputs tells us node input names in order
+    # (3) graph_signature.user_inputs_to_mutate tells us buffer & input mutations
+    # map (3) -> (2) for input order, -> (1) for input type
+    user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)}
+    mutation_names = list(graph_signature.user_inputs_to_mutate.keys())
+    assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]]
+    for output_name, input_name in graph_signature.user_inputs_to_mutate.items():
+        i = user_inputs_index[input_name]
+        input_spec = ep.graph_signature.input_specs[i]
+        assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER)
+        output_kind = (
+            OutputKind.BUFFER_MUTATION
+            if input_spec.kind == InputKind.BUFFER
+            else OutputKind.USER_INPUT_MUTATION
+        )
+        target = (
+            input_spec.target
+            if input_spec.kind == InputKind.BUFFER
+            else input_spec.arg.name
+        )
+        output_specs.append(
+            OutputSpec(
+                kind=output_kind,
+                arg=TensorArgument(name=output_name),
+                target=target,
+            )
+        )
+
+    # handle actual user outputs
+    for i, spec in enumerate(ep.graph_signature.output_specs):
+        output_specs.append(
+            OutputSpec(
+                OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
+                update_arg(spec.arg, new_outputs[len(mutation_names) + i]),
+                old_new_placeholder_map.get(spec.target, spec.target),
+            )
+        )
+
     if joint_loss_index is not None:
         assert graph_signature.backward_signature is not None
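
To make the three-way mapping described in the comments concrete, here is a self-contained sketch of the same bookkeeping. The InputKind/OutputKind/InputSpec classes below are simplified stand-ins for the real torch.export signature types, and the node names ("b_buf", "x", "add", "add_1") are hypothetical traced names:

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Optional

    class InputKind(Enum):
        USER_INPUT = auto()
        BUFFER = auto()

    class OutputKind(Enum):
        BUFFER_MUTATION = auto()
        USER_INPUT_MUTATION = auto()

    @dataclass
    class InputSpec:
        kind: InputKind
        name: str              # placeholder node name
        target: Optional[str]  # buffer FQN; None for user inputs

    # (1) input specs give each input's kind
    input_specs = [
        InputSpec(InputKind.BUFFER, "b_buf", "buf"),
        InputSpec(InputKind.USER_INPUT, "x", None),
    ]
    # (2) placeholder node names, in input order
    user_inputs = ["b_buf", "x"]
    # (3) mutation output node name -> mutated input node name
    user_inputs_to_mutate = {"add_1": "b_buf", "add": "x"}

    # map (3) -> (2) to find each mutated input's position, then -> (1) for its kind
    index = {name: i for i, name in enumerate(user_inputs)}
    for out_name, in_name in user_inputs_to_mutate.items():
        spec = input_specs[index[in_name]]
        if spec.kind == InputKind.BUFFER:
            kind, target = OutputKind.BUFFER_MUTATION, spec.target
        else:
            kind, target = OutputKind.USER_INPUT_MUTATION, spec.name
        print(out_name, kind, target)
    # add_1 OutputKind.BUFFER_MUTATION buf
    # add OutputKind.USER_INPUT_MUTATION x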