mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
fix problem of reduplicate input names (#14163)
Contributor: @guyang3532
This commit is contained in:
parent
0de4bc7050
commit
ba00f3a134
2 changed files with 66 additions and 12 deletions
|
|
@ -481,17 +481,17 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"})
|
||||
return dynamic_axes
|
||||
|
||||
def _add_input(name, input, onnx_graph, onnx_graph_input_names):
|
||||
def _add_input(name, input_value, onnx_graph, onnx_graph_input_names):
|
||||
"""Returns number of expanded non none inputs that _add_input processed"""
|
||||
|
||||
if input is None or isinstance(input, str):
|
||||
if name in input_names or input_value is None or isinstance(input_value, str):
|
||||
# Drop all None and string inputs and return 0.
|
||||
return
|
||||
|
||||
if isinstance(input, abc.Sequence):
|
||||
if isinstance(input_value, abc.Sequence):
|
||||
# If the input is a sequence (like a list), expand the list so that
|
||||
# each element of the list is an input by itself.
|
||||
for i, val in enumerate(input):
|
||||
for i, val in enumerate(input_value):
|
||||
# Name each input with the index appended to the original name of the
|
||||
# argument.
|
||||
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
|
||||
|
|
@ -499,10 +499,10 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
# Return here since the list by itself is not a valid input.
|
||||
# All the elements of the list have already been added as inputs individually.
|
||||
return
|
||||
elif isinstance(input, abc.Mapping):
|
||||
elif isinstance(input_value, abc.Mapping):
|
||||
# If the input is a mapping (like a dict), expand the dict so that
|
||||
# each element of the dict is an input by itself.
|
||||
for key, val in input.items():
|
||||
for key, val in input_value.items():
|
||||
_add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
|
||||
|
||||
# Return here since the dict by itself is not a valid input.
|
||||
|
|
@ -513,11 +513,11 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
# a part of the onnx graph or not.
|
||||
input_names.append(name)
|
||||
|
||||
if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input, torch.Tensor):
|
||||
if input.requires_grad:
|
||||
if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor):
|
||||
if input_value.requires_grad:
|
||||
input_names_require_grad.append(name)
|
||||
dynamic_axes.update(_add_dynamic_shape(name, input))
|
||||
input_shape.append(list(input.size()))
|
||||
dynamic_axes.update(_add_dynamic_shape(name, input_value))
|
||||
input_shape.append(list(input_value.size()))
|
||||
|
||||
# Ignore optional inputs explicitly specified as None
|
||||
# ONNX exporter may remove unused inputs
|
||||
|
|
@ -557,8 +557,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
# **kwargs is always the last argument of forward()
|
||||
for name, inp in kwargs.items():
|
||||
if name not in input_names:
|
||||
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
|
||||
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
|
||||
|
||||
return _InputInfo(
|
||||
names=input_names,
|
||||
|
|
|
|||
|
|
@ -5607,6 +5607,61 @@ def test_kwargs_dict_input():
|
|||
_test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))
|
||||
|
||||
|
||||
def test_named_kwargs_dict_input():
|
||||
class DictNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, *args, named_kwarg, **kwargs):
|
||||
a = named_kwarg["named_one"]
|
||||
b = named_kwarg["named_two"]["named_three"]
|
||||
c = named_kwarg["named_two"]["named_four"]
|
||||
d = named_kwarg["named_five"]["named_six"]
|
||||
e = named_kwarg["named_five"]["named_seven"]["named_eight"]
|
||||
batch = kwargs["batch"]
|
||||
f = batch["one_value"]
|
||||
g = batch["two_value"]["three_value"]
|
||||
h = batch["two_value"]["four_value"]
|
||||
i = batch["five_value"]["six_value"]
|
||||
j = batch["five_value"]["seven_value"]["eight_value"]
|
||||
return self.dummy + a + b + c + d + e + f + g + h + i + j
|
||||
|
||||
device = "cuda"
|
||||
N, D_in = 64, 784
|
||||
pt_model = DictNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
named_kwarg = {
|
||||
"named_one": torch.randn(N, D_in, device=device),
|
||||
"named_two": {
|
||||
"named_three": torch.randn(N, D_in, device=device),
|
||||
"named_four": torch.randn(N, D_in, device=device),
|
||||
},
|
||||
"named_five": {
|
||||
"named_six": torch.randn(N, D_in, device=device),
|
||||
"named_seven": {"named_eight": torch.randn(N, D_in, device=device)},
|
||||
},
|
||||
}
|
||||
batch = {
|
||||
"one_value": torch.randn(N, D_in, device=device),
|
||||
"two_value": {
|
||||
"three_value": torch.randn(N, D_in, device=device),
|
||||
"four_value": torch.randn(N, D_in, device=device),
|
||||
},
|
||||
"five_value": {
|
||||
"six_value": torch.randn(N, D_in, device=device),
|
||||
"seven_value": {"eight_value": torch.randn(N, D_in, device=device)},
|
||||
},
|
||||
}
|
||||
batch_copy = copy.deepcopy(batch)
|
||||
x_copy = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(
|
||||
pt_model(x, named_kwarg=named_kwarg, batch=batch), ort_model(x_copy, named_kwarg=named_kwarg, batch=batch_copy)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("training_mode", [False, True])
|
||||
def test_non_contiguous_tensors_as_inputs(training_mode):
|
||||
class NonContigousNet(torch.nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in a new issue