From ce403eea98fc408f66c8242e46c128e62569899e Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 10 Mar 2021 10:15:23 -0800 Subject: [PATCH] Add *args support for ORTModule inputs (#6883) --- .../_ortmodule_output_transformation.py | 54 ++++++++++++------ .../orttraining/python/training/ortmodule.py | 33 +++++++---- .../python/orttraining_test_ortmodule_api.py | 55 +++++++++++++++---- 3 files changed, 104 insertions(+), 38 deletions(-) diff --git a/orttraining/orttraining/python/training/_ortmodule_output_transformation.py b/orttraining/orttraining/python/training/_ortmodule_output_transformation.py index 9644a382e9..0e28b7e76a 100644 --- a/orttraining/orttraining/python/training/_ortmodule_output_transformation.py +++ b/orttraining/orttraining/python/training/_ortmodule_output_transformation.py @@ -1,6 +1,7 @@ from collections import abc import copy import functools +import inspect import torch import warnings @@ -159,7 +160,7 @@ def get_flattened_output_module(original_module): return FlattenedOutputModule(original_module) -def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs): +def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, *inputs, **kwargs): # Ignore optional inputs explicitly specified as None # ONNX exporter may remove unused inputs onnx_graph_input_names = [] @@ -171,23 +172,44 @@ def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs) input_names_require_grad = [] input_shape = [] - for input_idx, name in enumerate(all_input_names): - inp = None - if input_idx < len(inputs) and inputs[input_idx] is not None: - inp = inputs[input_idx] - elif name in kwargs and kwargs[name] is not None: - inp = kwargs[name] - if inp is not None and (onnx_graph is None or name in onnx_graph_input_names): - if inp.requires_grad: - # input_names_require_grad holds all input tensors that have requires_grad - input_names_require_grad.append(name) + for input_idx, input_parameter in enumerate(all_input_parameters): + if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL: + # Looking at VAR_POSITIONAL parameter (*args) in the original forward method. + # All the rest positional inputs go into this parameter. + var_positional_idx = 0 + for i in range(input_idx, len(inputs)): + name = f'var_positional_{input_parameter.name}{var_positional_idx}' + var_positional_idx += 1 + inp = inputs[i] + if inp is not None and (onnx_graph is None or name in onnx_graph_input_names): + if inp.requires_grad: + # input_names_require_grad holds all input tensors that have requires_grad + input_names_require_grad.append(name) - input_names.append(name) - dynamic_axes[name] = {} - for dim_idx in range(len(inp.shape)): - dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'}) + input_names.append(name) + dynamic_axes[name] = {} + for dim_idx in range(len(inp.shape)): + dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'}) - input_shape.append(list(inp.size())) + input_shape.append(list(inp.size())) + else: + name = input_parameter.name + inp = None + if input_idx < len(inputs) and inputs[input_idx] is not None: + inp = inputs[input_idx] + elif name in kwargs and kwargs[name] is not None: + inp = kwargs[name] + if inp is not None and (onnx_graph is None or name in onnx_graph_input_names): + if inp.requires_grad: + # input_names_require_grad holds all input tensors that have requires_grad + input_names_require_grad.append(name) + + input_names.append(name) + dynamic_axes[name] = {} + for dim_idx in range(len(inp.shape)): + dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'}) + + input_shape.append(list(inp.size())) return input_names, dynamic_axes, input_names_require_grad, input_shape def parse_outputs_for_onnx_export_and_extract_output_schema(module, inputs, kwargs): diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 1487757a11..70b0ac97c6 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -4,6 +4,7 @@ import logging import onnx import onnxruntime import torch +import inspect from inspect import signature from torch.utils.dlpack import from_dlpack, to_dlpack @@ -85,10 +86,10 @@ class ORTModule(torch.nn.Module): '''Forward pass starts here and continues at `_ORTModuleFunction.forward` ONNX model is exported the first time this method is executed. - Next, we build a full training graph with module_gradient_graph_builder. + Next, we build a full training graph with module_gradient_graph_builder. Finally, we instantiate the ONNX Runtime InferenceSession. ''' - # TODO: using pytorch for evaluation for now. We will use ORT for evaluation latter. + # TODO: using pytorch for evaluation for now. We will use ORT for evaluation later. if not self._is_training: return self._original_module(*inputs, **kwargs) @@ -103,7 +104,7 @@ class ORTModule(torch.nn.Module): _, _, input_names_require_grad, new_input_shape = \ _ortmodule_output_transformation.parse_inputs_for_onnx_export( - self._original_module_input_names, self._onnx_inference, *inputs, **kwargs) + self._original_module_parameters, self._onnx_inference, *inputs, **kwargs) # If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad if input_names_require_grad != self._input_names_require_grad: @@ -229,8 +230,13 @@ class ORTModule(torch.nn.Module): # Get the module that flattens the output from the original module into a tuple self._flattened_output_module = \ _ortmodule_output_transformation.get_flattened_output_module(self._original_module) - sig = signature(self._original_module.forward) - self._original_module_input_names = sig.parameters.keys() + self._original_module_parameters = signature(self._original_module.forward).parameters.values() + + # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. + for input_parameter in self._original_module_parameters: + if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: + raise NotImplementedError("The model's forward method has **kwargs parameter which is currently not supported.") + self._onnx_inference = None self._is_training = True @@ -345,20 +351,24 @@ class ORTModule(torch.nn.Module): TODO: How IO binding model inputs and outputs affects initializer copies? - ONNX Runtime forward requires an order list of: + ONNX Runtime forward requires an ordered list of: * User input: computed from forward InferenceSession * Initializers: computed from original PyTorch model parameters ''' # User inputs + non_none_inputs = [inp for inp in inputs if inp is not None] result = [] - for input_idx, name in enumerate(self._original_module_input_names): + for input_idx, name in enumerate(self._onnx_graphs_info.user_input_names): inp = None - if input_idx < len(inputs) and inputs[input_idx] is not None: - inp = inputs[input_idx] + if input_idx < len(non_none_inputs): + inp = non_none_inputs[input_idx] elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] - if inp is not None and name in self._onnx_graphs_info.user_input_names: + if inp is not None: result.append(inp) + else: + # TODO: Re-export ONNX if any input from _onnx_graphs_info.user_input_names is None. + raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.') # Initializers for param in self._flattened_output_module.named_parameters(): @@ -370,13 +380,12 @@ class ORTModule(torch.nn.Module): '''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input TODO: How to support dynamic axes? Dimensions are determined by samples - TODO: How to ingest **kwargs in proper order during export? ''' # Setup dynamic axes for onnx model input_names, dynamic_axes, self._input_names_require_grad, _ = \ _ortmodule_output_transformation.parse_inputs_for_onnx_export( - self._original_module_input_names, None, *inputs, **kwargs) + self._original_module_parameters, None, *inputs, **kwargs) output_names, output_dynamic_axes, self._original_module_output_schema = \ _ortmodule_output_transformation.parse_outputs_for_onnx_export_and_extract_output_schema( self._original_module, inputs, kwargs) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d611253c80..4044a392b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -83,6 +83,21 @@ class NeuralNetMultiplePositionalArguments(torch.nn.Module): out = self.fc2(out) return out +class NeuralNetMultiplePositionalArgumentsVarKeyword(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNetMultiplePositionalArgumentsVarKeyword, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1, input2, **kwargs): + model_input = input1 + input2 + out = self.fc1(model_input) + out = self.relu(out) + out = self.fc2(out) + return out + class NeuralNetPositionalArguments(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetPositionalArguments, self).__init__() @@ -218,18 +233,38 @@ def test_forward_call_multiple_positional_arguments(): output = ort_model(x, y) assert output is not None -# TODO: Re-enable after "Support models with dynamically defined inputs" done. -# def test_forward_call_positional_arguments(): -# device = 'cuda' +def test_forward_call_multiple_positional_arguments_var_keyword(): + device = 'cuda' -# N, D_in, H, D_out = 64, 784, 500, 10 -# model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) -# model = ORTModule(model) -# args = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)] + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetMultiplePositionalArgumentsVarKeyword(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) -# # Make sure model runs without any exception -# output = model(*args) -# assert output is not None + # TODO: remove exception check and uncomment the rest of the test when + # PyTorch ONNX exporter supports **kwargs. + with pytest.raises(NotImplementedError) as runtime_error: + ort_model = ORTModule(model) + assert '**kwargs' in str(runtime_error.value) + + # # Check that the original forward signature is preserved. + # assert signature(model.forward) == signature(ort_model.forward) + # x = torch.randn(N, D_in, device=device) + # y = torch.randn(N, D_in, device=device) + + # # Make sure model runs without any exception + # output = ort_model(x, y) + # assert output is not None + +def test_forward_call_positional_arguments(): + device = 'cuda' + + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) + model = ORTModule(model) + args = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)] + + # Make sure model runs without any exception + output = model(*args) + assert output is not None def test_forward_call_keyword_arguments(): device = 'cuda'