diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index fc9f8e51970..0135e7c282a 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -125,6 +125,33 @@ class DynamoExporterTest(common_utils.TestCase): onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),)) onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),)) + def test_onnx_export_control_flow_multi_outputs(self): + class CondModel(torch.nn.Module): + def forward(self, x): + z = torch.ones_like(x) + + def true_fn(x, z): + x = x + 1.0 + z = z * 1.0 + return x, z + + def false_fn(x, z): + x = x - 1.0 + z = z * 0.0 + return x, z + + x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z)) + return x, z + + onnx_program = torch.onnx.export( + CondModel(), + (torch.tensor([1, 2]),), + dynamo=True, + fallback=False, + ) + onnx_testing.assert_onnx_program(onnx_program) + onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),)) + def test_onnx_export_torchvision_ops(self): class VisionModel(torch.nn.Module): def __init__(self): @@ -142,8 +169,6 @@ class DynamoExporterTest(common_utils.TestCase): onnx_program = self.export(VisionModel(), args) onnx_testing.assert_onnx_program(onnx_program) - # TODO(justinchuby): Test multi-output HOPs - def test_empty(self): def func(x): return torch.empty(x.size(), dtype=torch.int64) diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index d1173093b2e..1e1c13b8fd5 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -645,15 +645,18 @@ def _handle_output_node( node_name_to_values: A mapping of FX node names to their produced ONNX ``Value``. graph_like: The ONNX graph at construction. """ - output_value_name = node.args[0][0].name # type: ignore[index,union-attr] - assert isinstance( - output_value_name, str - ), f"Bug: Expected {output_value_name!r} to be a string" - values = node_name_to_values[output_value_name] - if isinstance(values, Sequence): - graph_like.outputs.extend(values) - return - graph_like.outputs.append(values) + # node.args[0] can be a tuple with more than one elements. This happens when, + # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs + for output in node.args[0]: # type: ignore[index,union-attr] + output_value_name = output.name # type: ignore[union-attr] + assert isinstance( + output_value_name, str + ), f"Bug: Expected {output_value_name!r} to be a string" + values = node_name_to_values[output_value_name] + if isinstance(values, Sequence): + graph_like.outputs.extend(values) + return + graph_like.outputs.append(values) def _translate_fx_graph(