Support lists as inputs to ORTModule (#8311)

This commit is contained in:
baijumeswani 2021-07-07 13:04:19 -07:00 committed by GitHub
parent 9a855fe9e7
commit 6652d17dcd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 1 deletions

View file

@ -129,8 +129,23 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
* Initializers: computed from original PyTorch model parameters
'''
def _expand_inputs(current_input, non_none_inputs):
# The exporter handles input lists by expanding them so that each
# element of the list is its own input.
# ORTModule must match this behavior by also expanding the inputs.
if isinstance(current_input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself
for inp in current_input:
_expand_inputs(inp, non_none_inputs)
elif current_input is not None:
# else just collect all the non none inputs within non_none_inputs
non_none_inputs.append(current_input)
# User inputs
non_none_inputs = [inp for inp in inputs if inp is not None]
non_none_inputs = []
_expand_inputs(inputs, non_none_inputs)
buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
result = []
@ -398,6 +413,18 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
# Drop all None inputs.
return
if isinstance(input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself.
for i, val in enumerate(input):
# Name each input with the index appended to the original name of the
# argument.
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
# Return here since the list by itself is not a valid input.
# All the elements of the list have already been added as inputs individually.
return
# InputInfo should contain all the names irrespective of whether they are
# a part of the onnx graph or not.
input_names.append(name)

View file

@ -2823,3 +2823,88 @@ def test_input_with_string_exception():
with pytest.raises(TypeError) as ex_info:
_ = model(torch.randn(1, 2), 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
def test_ortmodule_list_input():
    """ORTModule output matches PyTorch when forward takes a flat list of tensors.

    Runs on the 'cuda' device.
    """

    class ListNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, batch):
            # Both list elements contribute to the output.
            first, second = batch[0], batch[1]
            return self.dummy + first + second

    device = 'cuda'
    batch_size, num_features = 64, 784

    pt_model = ListNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_inputs = [torch.randn(batch_size, num_features, device=device) for _ in range(2)]
    ort_inputs = copy.deepcopy(pt_inputs)

    expected = pt_model(pt_inputs)
    actual = ort_model(ort_inputs)
    _test_helpers.assert_values_are_close(expected, actual)
def test_ortmodule_list_input_with_unused_values():
    """ORTModule output matches PyTorch when one list element is read but unused.

    Runs on the 'cuda' device.
    """

    class ListNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, batch):
            # The first element is fetched but deliberately left out of the
            # result; only the second element reaches the output.
            unused = batch[0]
            used = batch[1]
            return self.dummy + used

    device = 'cuda'
    batch_size, num_features = 64, 784

    pt_model = ListNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_inputs = [torch.randn(batch_size, num_features, device=device) for _ in range(2)]
    ort_inputs = copy.deepcopy(pt_inputs)

    expected = pt_model(pt_inputs)
    actual = ort_model(ort_inputs)
    _test_helpers.assert_values_are_close(expected, actual)
def test_ortmodule_list_input_with_none_values():
    """ORTModule output matches PyTorch when the input list contains ``None``.

    Runs on the 'cuda' device.
    """

    class ListNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, batch):
            # Substitute a constant tensor when the first element is None.
            if batch[0] is not None:
                first = batch[0]
            else:
                first = torch.FloatTensor([2]).cuda()
            second = batch[1]
            return self.dummy + first + second

    device = 'cuda'
    batch_size, num_features = 64, 784

    pt_model = ListNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    pt_inputs = [None, torch.randn(batch_size, num_features, device=device)]
    ort_inputs = copy.deepcopy(pt_inputs)

    expected = pt_model(pt_inputs)
    actual = ort_model(ort_inputs)
    _test_helpers.assert_values_are_close(expected, actual)
def test_ortmodule_nested_list_input():
    """ORTModule output matches PyTorch when forward takes nested lists of tensors.

    Runs on the 'cuda' device.
    """

    class ListNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, batch):
            # Pull one tensor from every nesting level of the input.
            top = batch[0]
            mid_first, mid_second = batch[1][0], batch[1][1]
            last_outer = batch[2][0]
            last_inner = batch[2][1][0]
            return self.dummy + top + mid_first + mid_second + last_outer + last_inner

    device = 'cuda'
    batch_size, num_features = 64, 784

    pt_model = ListNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def make_tensor():
        return torch.randn(batch_size, num_features, device=device)

    # Literal evaluation order matches the depth-first order of the leaves.
    pt_inputs = [
        make_tensor(),
        [make_tensor(), make_tensor()],
        [make_tensor(), [make_tensor()]],
    ]
    ort_inputs = copy.deepcopy(pt_inputs)

    expected = pt_model(pt_inputs)
    actual = ort_model(ort_inputs)
    _test_helpers.assert_values_are_close(expected, actual)