Fix problem of duplicated input names (#14163)

Contributor: @guyang3532
This commit is contained in:
guyang3532 2023-02-11 04:57:51 +08:00 committed by GitHub
parent 0de4bc7050
commit ba00f3a134
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 12 deletions

View file

@ -481,17 +481,17 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"})
return dynamic_axes
def _add_input(name, input, onnx_graph, onnx_graph_input_names):
def _add_input(name, input_value, onnx_graph, onnx_graph_input_names):
"""Returns number of expanded non none inputs that _add_input processed"""
if input is None or isinstance(input, str):
if name in input_names or input_value is None or isinstance(input_value, str):
# Drop all None and string inputs and return 0.
return
if isinstance(input, abc.Sequence):
if isinstance(input_value, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself.
for i, val in enumerate(input):
for i, val in enumerate(input_value):
# Name each input with the index appended to the original name of the
# argument.
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
@ -499,10 +499,10 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
# Return here since the list by itself is not a valid input.
# All the elements of the list have already been added as inputs individually.
return
elif isinstance(input, abc.Mapping):
elif isinstance(input_value, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself.
for key, val in input.items():
for key, val in input_value.items():
_add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
# Return here since the dict by itself is not a valid input.
@ -513,11 +513,11 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
# a part of the onnx graph or not.
input_names.append(name)
if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input, torch.Tensor):
if input.requires_grad:
if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor):
if input_value.requires_grad:
input_names_require_grad.append(name)
dynamic_axes.update(_add_dynamic_shape(name, input))
input_shape.append(list(input.size()))
dynamic_axes.update(_add_dynamic_shape(name, input_value))
input_shape.append(list(input_value.size()))
# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
@ -557,8 +557,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
# **kwargs is always the last argument of forward()
for name, inp in kwargs.items():
if name not in input_names:
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
return _InputInfo(
names=input_names,

View file

@ -5607,6 +5607,61 @@ def test_kwargs_dict_input():
_test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))
def test_named_kwargs_dict_input():
    """Compare ORTModule with PyTorch when nested dicts are passed by keyword.

    One nested dict is bound to the keyword-only parameter ``named_kwarg`` and
    a second one travels through ``**kwargs`` under the key ``batch``. The
    outputs of the PyTorch model and of its ORTModule-wrapped copy are
    checked with ``_test_helpers.assert_values_are_close``.
    """

    class DictNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, *args, named_kwarg, **kwargs):
            named_two = named_kwarg["named_two"]
            named_five = named_kwarg["named_five"]
            batch = kwargs["batch"]
            two_value = batch["two_value"]
            five_value = batch["five_value"]

            # Leaves are listed in the exact order in which they are summed.
            leaves = (
                named_kwarg["named_one"],
                named_two["named_three"],
                named_two["named_four"],
                named_five["named_six"],
                named_five["named_seven"]["named_eight"],
                batch["one_value"],
                two_value["three_value"],
                two_value["four_value"],
                five_value["six_value"],
                five_value["seven_value"]["eight_value"],
            )

            # Left-to-right accumulation, starting from the dummy parameter.
            total = self.dummy
            for leaf in leaves:
                total = total + leaf
            return total

    device = "cuda"
    N, D_in = 64, 784

    def _rand():
        # Every input tensor shares the same shape and device.
        return torch.randn(N, D_in, device=device)

    pt_model = DictNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # Tensors are created in the same order as before: x, then the
    # named_kwarg leaves, then the batch leaves.
    x = _rand()
    named_kwarg = {
        "named_one": _rand(),
        "named_two": {
            "named_three": _rand(),
            "named_four": _rand(),
        },
        "named_five": {
            "named_six": _rand(),
            "named_seven": {"named_eight": _rand()},
        },
    }
    batch = {
        "one_value": _rand(),
        "two_value": {
            "three_value": _rand(),
            "four_value": _rand(),
        },
        "five_value": {
            "six_value": _rand(),
            "seven_value": {"eight_value": _rand()},
        },
    }

    # The ORT model receives copies of the positional input and of ``batch``;
    # ``named_kwarg`` is shared by both calls.
    batch_copy = copy.deepcopy(batch)
    x_copy = copy.deepcopy(x)

    pt_output = pt_model(x, named_kwarg=named_kwarg, batch=batch)
    ort_output = ort_model(x_copy, named_kwarg=named_kwarg, batch=batch_copy)
    _test_helpers.assert_values_are_close(pt_output, ort_output)
@pytest.mark.parametrize("training_mode", [False, True])
def test_non_contiguous_tensors_as_inputs(training_mode):
class NonContigousNet(torch.nn.Module):