[ONNX] Support subgraphs with 1+ outputs (#145860)

Fixed a bug in _handle_output_node where additional output values were not added as graph outputs

Fixes #145734
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145860
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-01-29 04:13:23 +00:00 committed by PyTorch MergeBot
parent fd515e4f59
commit 776bdb962c
2 changed files with 39 additions and 11 deletions

View file

@ -125,6 +125,33 @@ class DynamoExporterTest(common_utils.TestCase):
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
def test_onnx_export_control_flow_multi_outputs(self):
class CondModel(torch.nn.Module):
def forward(self, x):
z = torch.ones_like(x)
def true_fn(x, z):
x = x + 1.0
z = z * 1.0
return x, z
def false_fn(x, z):
x = x - 1.0
z = z * 0.0
return x, z
x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
return x, z
onnx_program = torch.onnx.export(
CondModel(),
(torch.tensor([1, 2]),),
dynamo=True,
fallback=False,
)
onnx_testing.assert_onnx_program(onnx_program)
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
def test_onnx_export_torchvision_ops(self):
class VisionModel(torch.nn.Module):
def __init__(self):
@ -142,8 +169,6 @@ class DynamoExporterTest(common_utils.TestCase):
onnx_program = self.export(VisionModel(), args)
onnx_testing.assert_onnx_program(onnx_program)
# TODO(justinchuby): Test multi-output HOPs
def test_empty(self):
def func(x):
return torch.empty(x.size(), dtype=torch.int64)

View file

@ -645,15 +645,18 @@ def _handle_output_node(
node_name_to_values: A mapping of FX node names to their produced ONNX ``Value``.
graph_like: The ONNX graph at construction.
"""
output_value_name = node.args[0][0].name # type: ignore[index,union-attr]
assert isinstance(
output_value_name, str
), f"Bug: Expected {output_value_name!r} to be a string"
values = node_name_to_values[output_value_name]
if isinstance(values, Sequence):
graph_like.outputs.extend(values)
return
graph_like.outputs.append(values)
# node.args[0] can be a tuple with more than one elements. This happens when,
# for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs
for output in node.args[0]: # type: ignore[index,union-attr]
output_value_name = output.name # type: ignore[union-attr]
assert isinstance(
output_value_name, str
), f"Bug: Expected {output_value_name!r} to be a string"
values = node_name_to_values[output_value_name]
if isinstance(values, Sequence):
graph_like.outputs.extend(values)
return
graph_like.outputs.append(values)
def _translate_fx_graph(