mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] handle buffer/input mutations for joint-graph (#144806)
Summary: previous construction of GraphSignature output specs didn't consider buffer/user input mutations.

Test Plan: test_experimental

Differential Revision: D68177409

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144806
Approved by: https://github.com/zhxchen17, https://github.com/avikchaudhuri
This commit is contained in:
parent
d7f45fc575
commit
774f21a370
2 changed files with 83 additions and 9 deletions
|
|
@ -10,6 +10,7 @@ from torch._functorch.aot_autograd import aot_export_module
|
|||
from torch.export import export, export_for_training
|
||||
from torch.export._trace import _convert_ts_to_export_experimental
|
||||
from torch.export.experimental import _export_forward_backward
|
||||
from torch.export.graph_signature import OutputKind
|
||||
from torch.testing import FileCheck
|
||||
|
||||
|
||||
|
|
@ -264,8 +265,6 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
|
|||
ep = _export_forward_backward(ep)
|
||||
|
||||
def test_joint_loss_index(self):
|
||||
from torch.export.graph_signature import OutputKind
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, index):
|
||||
super().__init__()
|
||||
|
|
@ -290,6 +289,48 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
|
|||
else:
|
||||
self.assertTrue(spec.kind != OutputKind.LOSS_OUTPUT)
|
||||
|
||||
def test_joint_buffer_input_mutations(self):
    """Joint-graph export must list buffer/user-input mutation outputs
    before the loss output in the graph signature."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(4, 4)
            self.register_buffer("buf", torch.randn(4))
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, label):
            # Mutate both a user input and a buffer so that both
            # mutation kinds show up in the output specs.
            x.add_(self.buf)
            x = self.l(x)
            self.buf.add_(2.0)
            return self.loss(x, label)

    example_inputs = (torch.randn(4, 4), torch.randint(0, 4, (4,)))
    joint_ep = _export_forward_backward(export(Model(), example_inputs))

    specs = joint_ep.graph_signature.output_specs
    self.assertEqual(len(specs), 5)
    # Mutations come first (buffer, then user input), then the loss.
    expected_prefix = [
        (OutputKind.BUFFER_MUTATION, "buf"),
        (OutputKind.USER_INPUT_MUTATION, "x"),
        (OutputKind.LOSS_OUTPUT, None),
    ]
    for spec, (kind, target) in zip(specs, expected_prefix):
        self.assertEqual(spec.kind, kind)
        if target is not None:
            self.assertEqual(spec.target, target)
|
||||
|
||||
|
||||
# Standard test-file entry point: run this file's tests directly
# (run_tests comes from torch.testing._internal common utilities —
# presumably imported earlier in the file; confirm against full source).
if __name__ == "__main__":
    run_tests()
|
||||
|
|
|
|||
|
|
@ -641,14 +641,47 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
|||
for i, spec in enumerate(ep.graph_signature.input_specs)
|
||||
]
|
||||
|
||||
output_specs = [
|
||||
OutputSpec(
|
||||
OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
|
||||
update_arg(spec.arg, new_outputs[i]),
|
||||
old_new_placeholder_map.get(spec.target, spec.target),
|
||||
output_specs = []
|
||||
|
||||
# handle buffer & input mutations; these appear before loss output & gradients
|
||||
# (1) ep.graph_signature.input_specs tells us types of inputs
|
||||
# (2) graph_signature.user_inputs tells us node input names in order
|
||||
# (3) graph_signature.user_inputs_to_mutate tells us buffer & input mutations
|
||||
# map (3) -> (2) for input order, -> (1) for input type
|
||||
user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)}
|
||||
mutation_names = list(graph_signature.user_inputs_to_mutate.keys())
|
||||
assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]]
|
||||
for output_name, input_name in graph_signature.user_inputs_to_mutate.items():
|
||||
i = user_inputs_index[input_name]
|
||||
input_spec = ep.graph_signature.input_specs[i]
|
||||
assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER)
|
||||
output_kind = (
|
||||
OutputKind.BUFFER_MUTATION
|
||||
if input_spec.kind == InputKind.BUFFER
|
||||
else OutputKind.USER_INPUT_MUTATION
|
||||
)
|
||||
target = (
|
||||
input_spec.target
|
||||
if input_spec.kind == InputKind.BUFFER
|
||||
else input_spec.arg.name
|
||||
)
|
||||
output_specs.append(
|
||||
OutputSpec(
|
||||
kind=output_kind,
|
||||
arg=TensorArgument(name=output_name),
|
||||
target=target,
|
||||
)
|
||||
)
|
||||
|
||||
# handle actual user outputs
|
||||
for i, spec in enumerate(ep.graph_signature.output_specs):
|
||||
output_specs.append(
|
||||
OutputSpec(
|
||||
OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
|
||||
update_arg(spec.arg, new_outputs[len(mutation_names) + i]),
|
||||
old_new_placeholder_map.get(spec.target, spec.target),
|
||||
)
|
||||
)
|
||||
for i, spec in enumerate(ep.graph_signature.output_specs)
|
||||
]
|
||||
|
||||
if joint_loss_index is not None:
|
||||
assert graph_signature.backward_signature is not None
|
||||
|
|
|
|||
Loading…
Reference in a new issue