mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Support lists as inputs to ORTModule (#8311)
This commit is contained in:
parent
9a855fe9e7
commit
6652d17dcd
2 changed files with 113 additions and 1 deletions
|
|
@ -129,8 +129,23 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
|
|||
* Initializers: computed from original PyTorch model parameters
|
||||
'''
|
||||
|
||||
def _expand_inputs(current_input, non_none_inputs):
|
||||
# The exporter handles input lists by expanding them so that each
|
||||
# element of the list is its own input.
|
||||
# ORTModule must match this behavior by also expanding the inputs.
|
||||
if isinstance(current_input, abc.Sequence):
|
||||
# If the input is a sequence (like a list), expand the list so that
|
||||
# each element of the list is an input by itself
|
||||
for inp in current_input:
|
||||
_expand_inputs(inp, non_none_inputs)
|
||||
elif current_input is not None:
|
||||
# else just collect all the non none inputs within non_none_inputs
|
||||
non_none_inputs.append(current_input)
|
||||
|
||||
|
||||
# User inputs
|
||||
non_none_inputs = [inp for inp in inputs if inp is not None]
|
||||
non_none_inputs = []
|
||||
_expand_inputs(inputs, non_none_inputs)
|
||||
buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
|
||||
result = []
|
||||
|
||||
|
|
@ -398,6 +413,18 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
|
|||
# Drop all None inputs.
|
||||
return
|
||||
|
||||
if isinstance(input, abc.Sequence):
|
||||
# If the input is a sequence (like a list), expand the list so that
|
||||
# each element of the list is an input by itself.
|
||||
for i, val in enumerate(input):
|
||||
# Name each input with the index appended to the original name of the
|
||||
# argument.
|
||||
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
|
||||
|
||||
# Return here since the list by itself is not a valid input.
|
||||
# All the elements of the list have already been added as inputs individually.
|
||||
return
|
||||
|
||||
# InputInfo should contain all the names irrespective of whether they are
|
||||
# a part of the onnx graph or not.
|
||||
input_names.append(name)
|
||||
|
|
|
|||
|
|
@ -2823,3 +2823,88 @@ def test_input_with_string_exception():
|
|||
with pytest.raises(TypeError) as ex_info:
|
||||
_ = model(torch.randn(1, 2), 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
|
||||
def test_ortmodule_list_input():
|
||||
class ListNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ListNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, batch):
|
||||
a = batch[0]
|
||||
b = batch[1]
|
||||
return self.dummy + a + b
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = ListNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
|
||||
|
||||
def test_ortmodule_list_input_with_unused_values():
|
||||
class ListNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ListNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, batch):
|
||||
a = batch[0]
|
||||
b = batch[1]
|
||||
return self.dummy + b
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = ListNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
|
||||
|
||||
def test_ortmodule_list_input_with_none_values():
|
||||
class ListNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ListNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, batch):
|
||||
a = batch[0] if batch[0] is not None else torch.FloatTensor([2]).cuda()
|
||||
b = batch[1]
|
||||
return self.dummy + a + b
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = ListNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = [None, torch.randn(N, D_in, device=device)]
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
|
||||
|
||||
def test_ortmodule_nested_list_input():
|
||||
class ListNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ListNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, batch):
|
||||
a = batch[0]
|
||||
b = batch[1][0]
|
||||
c = batch[1][1]
|
||||
d = batch[2][0]
|
||||
e = batch[2][1][0]
|
||||
return self.dummy + a + b + c + d + e
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = ListNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = [torch.randn(N, D_in, device=device),
|
||||
[torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)],
|
||||
[torch.randn(N, D_in, device=device), [torch.randn(N, D_in, device=device)]]]
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
|
||||
|
|
|
|||
Loading…
Reference in a new issue