From 6652d17dcd252abd3bbd2d3077b46433617f202e Mon Sep 17 00:00:00 2001
From: baijumeswani
Date: Wed, 7 Jul 2021 13:04:19 -0700
Subject: [PATCH] Support lists as inputs to ORTModule (#8311)

---
 .../python/training/ortmodule/_io.py          | 29 ++++++-
 .../python/orttraining_test_ortmodule_api.py  | 85 +++++++++++++++++++
 2 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py
index 086e2fd73c..2f2e767333 100644
--- a/orttraining/orttraining/python/training/ortmodule/_io.py
+++ b/orttraining/orttraining/python/training/ortmodule/_io.py
@@ -129,8 +129,23 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
         * Initializers: computed from original PyTorch model parameters
     '''
 
+    def _expand_inputs(current_input, non_none_inputs):
+        # The exporter handles input lists by expanding them so that each
+        # element of the list is its own input.
+        # ORTModule must match this behavior by also expanding the inputs.
+        if isinstance(current_input, abc.Sequence) and not isinstance(current_input, (str, bytes)):
+            # Expand a sequence (like a list) so each element is its own input. str/bytes are
+            # left whole: a str's elements are themselves str, so expansion would never terminate.
+            for inp in current_input:
+                _expand_inputs(inp, non_none_inputs)
+        elif current_input is not None:
+            # Otherwise collect every non-None input (str/bytes included) as-is in non_none_inputs
+            non_none_inputs.append(current_input)
+
+
     # User inputs
-    non_none_inputs = [inp for inp in inputs if inp is not None]
+    non_none_inputs = []
+    _expand_inputs(inputs, non_none_inputs)
     buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
     result = []
 
@@ -398,6 +413,18 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, inputs, kwarg
             # Drop all None inputs.
             return
 
+        if isinstance(input, abc.Sequence) and not isinstance(input, (str, bytes)):
+            # Expand a sequence (like a list) so each element is its own input. str/bytes are
+            # left whole: a str's elements are themselves str, so expansion would never terminate.
+            for i, val in enumerate(input):
+                # Name each input with the index appended to the original name of the
+                # argument.
+                _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
+
+            # Return here since the list by itself is not a valid input.
+            # All the elements of the list have already been added as inputs individually.
+            return
+
         # InputInfo should contain all the names irrespective of whether they are
         # a part of the onnx graph or not.
         input_names.append(name)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index ac73fd06fc..94b2d973b2 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -2823,3 +2823,88 @@ def test_input_with_string_exception():
     with pytest.raises(TypeError) as ex_info:
         _ = model(torch.randn(1, 2), 'hello')
     assert "ORTModule does not support the following model data type " in str(ex_info.value)
+
+def test_ortmodule_list_input():
+    class ListNet(torch.nn.Module):
+        def __init__(self):
+            super(ListNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, batch):
+            a = batch[0]
+            b = batch[1]
+            return self.dummy + a + b
+
+    device = 'cuda'
+    N, D_in, H, D_out = 64, 784, 500, 10
+    pt_model = ListNet().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    x = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
+    y = copy.deepcopy(x)
+
+    _test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
+
+def test_ortmodule_list_input_with_unused_values():
+    class ListNet(torch.nn.Module):
+        def __init__(self):
+            super(ListNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, batch):
+            a = batch[0]
+            b = batch[1]
+            return self.dummy + b
+
+    device = 'cuda'
+    N, D_in, H, D_out = 64, 784, 500, 10
+    pt_model = ListNet().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    x = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
+    y = copy.deepcopy(x)
+
+    _test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
+
+def test_ortmodule_list_input_with_none_values():
+    class ListNet(torch.nn.Module):
+        def __init__(self):
+            super(ListNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, batch):
+            a = batch[0] if batch[0] is not None else torch.FloatTensor([2]).cuda()
+            b = batch[1]
+            return self.dummy + a + b
+
+    device = 'cuda'
+    N, D_in, H, D_out = 64, 784, 500, 10
+    pt_model = ListNet().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    x = [None, torch.randn(N, D_in, device=device)]
+    y = copy.deepcopy(x)
+
+    _test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
+
+def test_ortmodule_nested_list_input():
+    class ListNet(torch.nn.Module):
+        def __init__(self):
+            super(ListNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, batch):
+            a = batch[0]
+            b = batch[1][0]
+            c = batch[1][1]
+            d = batch[2][0]
+            e = batch[2][1][0]
+            return self.dummy + a + b + c + d + e
+
+    device = 'cuda'
+    N, D_in, H, D_out = 64, 784, 500, 10
+    pt_model = ListNet().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    x = [torch.randn(N, D_in, device=device),
+        [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)],
+        [torch.randn(N, D_in, device=device), [torch.randn(N, D_in, device=device)]]]
+    y = copy.deepcopy(x)
+
+    _test_helpers.assert_values_are_close(pt_model(x), ort_model(y))