diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py
index 307a8bcf01f..f95484f0a12 100644
--- a/test/export/test_experimental.py
+++ b/test/export/test_experimental.py
@@ -10,6 +10,7 @@ from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export, export_for_training
 from torch.export._trace import _convert_ts_to_export_experimental
 from torch.export.experimental import _export_forward_backward
+from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
 
 
@@ -264,8 +265,6 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
         ep = _export_forward_backward(ep)
 
     def test_joint_loss_index(self):
-        from torch.export.graph_signature import OutputKind
-
         class Foo(torch.nn.Module):
             def __init__(self, index):
                 super().__init__()
@@ -290,6 +289,48 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
             else:
                 self.assertTrue(spec.kind != OutputKind.LOSS_OUTPUT)
 
+    def test_joint_buffer_input_mutations(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l = torch.nn.Linear(4, 4)
+                self.register_buffer("buf", torch.randn(4))
+                self.loss = torch.nn.CrossEntropyLoss()
+
+            def forward(self, x, label):
+                x.add_(self.buf)
+                x = self.l(x)
+                self.buf.add_(2.0)
+                return self.loss(x, label)
+
+        inputs = (
+            torch.randn(4, 4),
+            torch.randint(0, 4, (4,)),
+        )
+        ep = export(Foo(), inputs)
+        ep_joint = _export_forward_backward(ep)
+        self.assertEqual(len(ep_joint.graph_signature.output_specs), 5)
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[0].kind,
+            OutputKind.BUFFER_MUTATION,
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[0].target,
+            "buf",
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[1].kind,
+            OutputKind.USER_INPUT_MUTATION,
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[1].target,
+            "x",
+        )
+        self.assertEqual(
+            ep_joint.graph_signature.output_specs[2].kind,
+            OutputKind.LOSS_OUTPUT,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 4201fb3875d..c308b331299 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -641,14 +641,47 @@ def _decompose_and_get_gm_with_new_signature_constants(
         for i, spec in enumerate(ep.graph_signature.input_specs)
     ]
 
-    output_specs = [
-        OutputSpec(
-            OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
-            update_arg(spec.arg, new_outputs[i]),
-            old_new_placeholder_map.get(spec.target, spec.target),
+    output_specs = []
+
+    # handle buffer & input mutations; these appear before loss output & gradients
+    # (1) ep.graph_signature.input_specs tells us types of inputs
+    # (2) graph_signature.user_inputs tells us node input names in order
+    # (3) graph_signature.user_inputs_to_mutate tells us buffer & input mutations
+    # map (3) -> (2) for input order, -> (1) for input type
+    user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)}
+    mutation_names = list(graph_signature.user_inputs_to_mutate.keys())
+    assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]]
+    for output_name, input_name in graph_signature.user_inputs_to_mutate.items():
+        i = user_inputs_index[input_name]
+        input_spec = ep.graph_signature.input_specs[i]
+        assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER)
+        output_kind = (
+            OutputKind.BUFFER_MUTATION
+            if input_spec.kind == InputKind.BUFFER
+            else OutputKind.USER_INPUT_MUTATION
+        )
+        target = (
+            input_spec.target
+            if input_spec.kind == InputKind.BUFFER
+            else input_spec.arg.name
+        )
+        output_specs.append(
+            OutputSpec(
+                kind=output_kind,
+                arg=TensorArgument(name=output_name),
+                target=target,
+            )
+        )
+
+    # handle actual user outputs
+    for i, spec in enumerate(ep.graph_signature.output_specs):
+        output_specs.append(
+            OutputSpec(
+                OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
+                update_arg(spec.arg, new_outputs[len(mutation_names) + i]),
+                old_new_placeholder_map.get(spec.target, spec.target),
+            )
         )
-        for i, spec in enumerate(ep.graph_signature.output_specs)
-    ]
 
     if joint_loss_index is not None:
         assert graph_signature.backward_signature is not None