Add *args support for ORTModule inputs (#6883)

This commit is contained in:
Sergii Dymchenko 2021-03-10 10:15:23 -08:00 committed by GitHub
parent 1e13e2666e
commit ce403eea98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 38 deletions

View file

@ -1,6 +1,7 @@
from collections import abc
import copy
import functools
import inspect
import torch
import warnings
@ -159,7 +160,7 @@ def get_flattened_output_module(original_module):
return FlattenedOutputModule(original_module)
def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs):
def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, *inputs, **kwargs):
# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
onnx_graph_input_names = []
@ -171,23 +172,44 @@ def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs)
input_names_require_grad = []
input_shape = []
for input_idx, name in enumerate(all_input_names):
inp = None
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
if inp.requires_grad:
# input_names_require_grad holds all input tensors that have requires_grad
input_names_require_grad.append(name)
for input_idx, input_parameter in enumerate(all_input_parameters):
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
# Looking at VAR_POSITIONAL parameter (*args) in the original forward method.
# All the rest positional inputs go into this parameter.
var_positional_idx = 0
for i in range(input_idx, len(inputs)):
name = f'var_positional_{input_parameter.name}{var_positional_idx}'
var_positional_idx += 1
inp = inputs[i]
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
if inp.requires_grad:
# input_names_require_grad holds all input tensors that have requires_grad
input_names_require_grad.append(name)
input_names.append(name)
dynamic_axes[name] = {}
for dim_idx in range(len(inp.shape)):
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
input_names.append(name)
dynamic_axes[name] = {}
for dim_idx in range(len(inp.shape)):
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
input_shape.append(list(inp.size()))
input_shape.append(list(inp.size()))
else:
name = input_parameter.name
inp = None
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
if inp.requires_grad:
# input_names_require_grad holds all input tensors that have requires_grad
input_names_require_grad.append(name)
input_names.append(name)
dynamic_axes[name] = {}
for dim_idx in range(len(inp.shape)):
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
input_shape.append(list(inp.size()))
return input_names, dynamic_axes, input_names_require_grad, input_shape
def parse_outputs_for_onnx_export_and_extract_output_schema(module, inputs, kwargs):

View file

@ -4,6 +4,7 @@ import logging
import onnx
import onnxruntime
import torch
import inspect
from inspect import signature
from torch.utils.dlpack import from_dlpack, to_dlpack
@ -85,10 +86,10 @@ class ORTModule(torch.nn.Module):
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
ONNX model is exported the first time this method is executed.
Next, we build a full training graph with module_gradient_graph_builder.
Next, we build a full training graph with module_gradient_graph_builder.
Finally, we instantiate the ONNX Runtime InferenceSession.
'''
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation latter.
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation later.
if not self._is_training:
return self._original_module(*inputs, **kwargs)
@ -103,7 +104,7 @@ class ORTModule(torch.nn.Module):
_, _, input_names_require_grad, new_input_shape = \
_ortmodule_output_transformation.parse_inputs_for_onnx_export(
self._original_module_input_names, self._onnx_inference, *inputs, **kwargs)
self._original_module_parameters, self._onnx_inference, *inputs, **kwargs)
# If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
if input_names_require_grad != self._input_names_require_grad:
@ -229,8 +230,13 @@ class ORTModule(torch.nn.Module):
# Get the module that flattens the output from the original module into a tuple
self._flattened_output_module = \
_ortmodule_output_transformation.get_flattened_output_module(self._original_module)
sig = signature(self._original_module.forward)
self._original_module_input_names = sig.parameters.keys()
self._original_module_parameters = signature(self._original_module.forward).parameters.values()
# TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters.
for input_parameter in self._original_module_parameters:
if input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
raise NotImplementedError("The model's forward method has **kwargs parameter which is currently not supported.")
self._onnx_inference = None
self._is_training = True
@ -345,20 +351,24 @@ class ORTModule(torch.nn.Module):
TODO: How IO binding model inputs and outputs affects initializer copies?
ONNX Runtime forward requires an order list of:
ONNX Runtime forward requires an ordered list of:
* User input: computed from forward InferenceSession
* Initializers: computed from original PyTorch model parameters
'''
# User inputs
non_none_inputs = [inp for inp in inputs if inp is not None]
result = []
for input_idx, name in enumerate(self._original_module_input_names):
for input_idx, name in enumerate(self._onnx_graphs_info.user_input_names):
inp = None
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
if input_idx < len(non_none_inputs):
inp = non_none_inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
if inp is not None and name in self._onnx_graphs_info.user_input_names:
if inp is not None:
result.append(inp)
else:
# TODO: Re-export ONNX if any input from _onnx_graphs_info.user_input_names is None.
raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')
# Initializers
for param in self._flattened_output_module.named_parameters():
@ -370,13 +380,12 @@ class ORTModule(torch.nn.Module):
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
TODO: How to support dynamic axes? Dimensions are determined by samples
TODO: How to ingest **kwargs in proper order during export?
'''
# Setup dynamic axes for onnx model
input_names, dynamic_axes, self._input_names_require_grad, _ = \
_ortmodule_output_transformation.parse_inputs_for_onnx_export(
self._original_module_input_names, None, *inputs, **kwargs)
self._original_module_parameters, None, *inputs, **kwargs)
output_names, output_dynamic_axes, self._original_module_output_schema = \
_ortmodule_output_transformation.parse_outputs_for_onnx_export_and_extract_output_schema(
self._original_module, inputs, kwargs)

View file

@ -83,6 +83,21 @@ class NeuralNetMultiplePositionalArguments(torch.nn.Module):
out = self.fc2(out)
return out
class NeuralNetMultiplePositionalArgumentsVarKeyword(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetMultiplePositionalArgumentsVarKeyword, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1, input2, **kwargs):
model_input = input1 + input2
out = self.fc1(model_input)
out = self.relu(out)
out = self.fc2(out)
return out
class NeuralNetPositionalArguments(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetPositionalArguments, self).__init__()
@ -218,18 +233,38 @@ def test_forward_call_multiple_positional_arguments():
output = ort_model(x, y)
assert output is not None
# TODO: Re-enable after "Support models with dynamically defined inputs" done.
# def test_forward_call_positional_arguments():
# device = 'cuda'
def test_forward_call_multiple_positional_arguments_var_keyword():
device = 'cuda'
# N, D_in, H, D_out = 64, 784, 500, 10
# model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
# model = ORTModule(model)
# args = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetMultiplePositionalArgumentsVarKeyword(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
# # Make sure model runs without any exception
# output = model(*args)
# assert output is not None
# TODO: remove exception check and uncomment the rest of the test when
# PyTorch ONNX exporter supports **kwargs.
with pytest.raises(NotImplementedError) as runtime_error:
ort_model = ORTModule(model)
assert '**kwargs' in str(runtime_error.value)
# # Check that the original forward signature is preserved.
# assert signature(model.forward) == signature(ort_model.forward)
# x = torch.randn(N, D_in, device=device)
# y = torch.randn(N, D_in, device=device)
# # Make sure model runs without any exception
# output = ort_model(x, y)
# assert output is not None
def test_forward_call_positional_arguments():
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
model = ORTModule(model)
args = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)]
# Make sure model runs without any exception
output = model(*args)
assert output is not None
def test_forward_call_keyword_arguments():
device = 'cuda'