mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX] Support subgraphs with 1+ outputs (#145860)
Fixed a bug in _handle_output_node where additional output values were not added as graph outputs Fixes #145734 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145860 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
fd515e4f59
commit
776bdb962c
2 changed files with 39 additions and 11 deletions
|
|
@ -125,6 +125,33 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
|
||||
|
||||
def test_onnx_export_control_flow_multi_outputs(self):
|
||||
class CondModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
z = torch.ones_like(x)
|
||||
|
||||
def true_fn(x, z):
|
||||
x = x + 1.0
|
||||
z = z * 1.0
|
||||
return x, z
|
||||
|
||||
def false_fn(x, z):
|
||||
x = x - 1.0
|
||||
z = z * 0.0
|
||||
return x, z
|
||||
|
||||
x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
|
||||
return x, z
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
CondModel(),
|
||||
(torch.tensor([1, 2]),),
|
||||
dynamo=True,
|
||||
fallback=False,
|
||||
)
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
|
||||
|
||||
def test_onnx_export_torchvision_ops(self):
|
||||
class VisionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -142,8 +169,6 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
onnx_program = self.export(VisionModel(), args)
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
|
||||
# TODO(justinchuby): Test multi-output HOPs
|
||||
|
||||
def test_empty(self):
|
||||
def func(x):
|
||||
return torch.empty(x.size(), dtype=torch.int64)
|
||||
|
|
|
|||
|
|
@ -645,15 +645,18 @@ def _handle_output_node(
|
|||
node_name_to_values: A mapping of FX node names to their produced ONNX ``Value``.
|
||||
graph_like: The ONNX graph at construction.
|
||||
"""
|
||||
output_value_name = node.args[0][0].name # type: ignore[index,union-attr]
|
||||
assert isinstance(
|
||||
output_value_name, str
|
||||
), f"Bug: Expected {output_value_name!r} to be a string"
|
||||
values = node_name_to_values[output_value_name]
|
||||
if isinstance(values, Sequence):
|
||||
graph_like.outputs.extend(values)
|
||||
return
|
||||
graph_like.outputs.append(values)
|
||||
# node.args[0] can be a tuple with more than one elements. This happens when,
|
||||
# for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs
|
||||
for output in node.args[0]: # type: ignore[index,union-attr]
|
||||
output_value_name = output.name # type: ignore[union-attr]
|
||||
assert isinstance(
|
||||
output_value_name, str
|
||||
), f"Bug: Expected {output_value_name!r} to be a string"
|
||||
values = node_name_to_values[output_value_name]
|
||||
if isinstance(values, Sequence):
|
||||
graph_like.outputs.extend(values)
|
||||
return
|
||||
graph_like.outputs.append(values)
|
||||
|
||||
|
||||
def _translate_fx_graph(
|
||||
|
|
|
|||
Loading…
Reference in a new issue