mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Add *args support for ORTModule inputs (#6883)
This commit is contained in:
parent
1e13e2666e
commit
ce403eea98
3 changed files with 104 additions and 38 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from collections import abc
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
|
|
@ -159,7 +160,7 @@ def get_flattened_output_module(original_module):
|
|||
|
||||
return FlattenedOutputModule(original_module)
|
||||
|
||||
def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs):
|
||||
def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, *inputs, **kwargs):
|
||||
# Ignore optional inputs explicitly specified as None
|
||||
# ONNX exporter may remove unused inputs
|
||||
onnx_graph_input_names = []
|
||||
|
|
@ -171,23 +172,44 @@ def parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs)
|
|||
input_names_require_grad = []
|
||||
input_shape = []
|
||||
|
||||
for input_idx, name in enumerate(all_input_names):
|
||||
inp = None
|
||||
if input_idx < len(inputs) and inputs[input_idx] is not None:
|
||||
inp = inputs[input_idx]
|
||||
elif name in kwargs and kwargs[name] is not None:
|
||||
inp = kwargs[name]
|
||||
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
|
||||
if inp.requires_grad:
|
||||
# input_names_require_grad holds all input tensors that have requires_grad
|
||||
input_names_require_grad.append(name)
|
||||
for input_idx, input_parameter in enumerate(all_input_parameters):
|
||||
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
# Looking at VAR_POSITIONAL parameter (*args) in the original forward method.
|
||||
# All the rest positional inputs go into this parameter.
|
||||
var_positional_idx = 0
|
||||
for i in range(input_idx, len(inputs)):
|
||||
name = f'var_positional_{input_parameter.name}{var_positional_idx}'
|
||||
var_positional_idx += 1
|
||||
inp = inputs[i]
|
||||
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
|
||||
if inp.requires_grad:
|
||||
# input_names_require_grad holds all input tensors that have requires_grad
|
||||
input_names_require_grad.append(name)
|
||||
|
||||
input_names.append(name)
|
||||
dynamic_axes[name] = {}
|
||||
for dim_idx in range(len(inp.shape)):
|
||||
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
|
||||
input_names.append(name)
|
||||
dynamic_axes[name] = {}
|
||||
for dim_idx in range(len(inp.shape)):
|
||||
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
|
||||
|
||||
input_shape.append(list(inp.size()))
|
||||
input_shape.append(list(inp.size()))
|
||||
else:
|
||||
name = input_parameter.name
|
||||
inp = None
|
||||
if input_idx < len(inputs) and inputs[input_idx] is not None:
|
||||
inp = inputs[input_idx]
|
||||
elif name in kwargs and kwargs[name] is not None:
|
||||
inp = kwargs[name]
|
||||
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
|
||||
if inp.requires_grad:
|
||||
# input_names_require_grad holds all input tensors that have requires_grad
|
||||
input_names_require_grad.append(name)
|
||||
|
||||
input_names.append(name)
|
||||
dynamic_axes[name] = {}
|
||||
for dim_idx in range(len(inp.shape)):
|
||||
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
|
||||
|
||||
input_shape.append(list(inp.size()))
|
||||
return input_names, dynamic_axes, input_names_require_grad, input_shape
|
||||
|
||||
def parse_outputs_for_onnx_export_and_extract_output_schema(module, inputs, kwargs):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import onnx
|
||||
import onnxruntime
|
||||
import torch
|
||||
import inspect
|
||||
from inspect import signature
|
||||
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
|
@ -85,10 +86,10 @@ class ORTModule(torch.nn.Module):
|
|||
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
|
||||
|
||||
ONNX model is exported the first time this method is executed.
|
||||
Next, we build a full training graph with module_gradient_graph_builder.
|
||||
Next, we build a full training graph with module_gradient_graph_builder.
|
||||
Finally, we instantiate the ONNX Runtime InferenceSession.
|
||||
'''
|
||||
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation latter.
|
||||
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation later.
|
||||
if not self._is_training:
|
||||
return self._original_module(*inputs, **kwargs)
|
||||
|
||||
|
|
@ -103,7 +104,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
_, _, input_names_require_grad, new_input_shape = \
|
||||
_ortmodule_output_transformation.parse_inputs_for_onnx_export(
|
||||
self._original_module_input_names, self._onnx_inference, *inputs, **kwargs)
|
||||
self._original_module_parameters, self._onnx_inference, *inputs, **kwargs)
|
||||
# If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder
|
||||
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
|
||||
if input_names_require_grad != self._input_names_require_grad:
|
||||
|
|
@ -229,8 +230,13 @@ class ORTModule(torch.nn.Module):
|
|||
# Get the module that flattens the output from the original module into a tuple
|
||||
self._flattened_output_module = \
|
||||
_ortmodule_output_transformation.get_flattened_output_module(self._original_module)
|
||||
sig = signature(self._original_module.forward)
|
||||
self._original_module_input_names = sig.parameters.keys()
|
||||
self._original_module_parameters = signature(self._original_module.forward).parameters.values()
|
||||
|
||||
# TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters.
|
||||
for input_parameter in self._original_module_parameters:
|
||||
if input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
raise NotImplementedError("The model's forward method has **kwargs parameter which is currently not supported.")
|
||||
|
||||
self._onnx_inference = None
|
||||
self._is_training = True
|
||||
|
||||
|
|
@ -345,20 +351,24 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
TODO: How does IO binding of model inputs and outputs affect initializer copies?
|
||||
|
||||
ONNX Runtime forward requires an order list of:
|
||||
ONNX Runtime forward requires an ordered list of:
|
||||
* User input: computed from forward InferenceSession
|
||||
* Initializers: computed from original PyTorch model parameters
|
||||
'''
|
||||
# User inputs
|
||||
non_none_inputs = [inp for inp in inputs if inp is not None]
|
||||
result = []
|
||||
for input_idx, name in enumerate(self._original_module_input_names):
|
||||
for input_idx, name in enumerate(self._onnx_graphs_info.user_input_names):
|
||||
inp = None
|
||||
if input_idx < len(inputs) and inputs[input_idx] is not None:
|
||||
inp = inputs[input_idx]
|
||||
if input_idx < len(non_none_inputs):
|
||||
inp = non_none_inputs[input_idx]
|
||||
elif name in kwargs and kwargs[name] is not None:
|
||||
inp = kwargs[name]
|
||||
if inp is not None and name in self._onnx_graphs_info.user_input_names:
|
||||
if inp is not None:
|
||||
result.append(inp)
|
||||
else:
|
||||
# TODO: Re-export ONNX if any input from _onnx_graphs_info.user_input_names is None.
|
||||
raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')
|
||||
|
||||
# Initializers
|
||||
for param in self._flattened_output_module.named_parameters():
|
||||
|
|
@ -370,13 +380,12 @@ class ORTModule(torch.nn.Module):
|
|||
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
|
||||
|
||||
TODO: How to support dynamic axes? Dimensions are determined by samples
|
||||
TODO: How to ingest **kwargs in proper order during export?
|
||||
'''
|
||||
|
||||
# Setup dynamic axes for onnx model
|
||||
input_names, dynamic_axes, self._input_names_require_grad, _ = \
|
||||
_ortmodule_output_transformation.parse_inputs_for_onnx_export(
|
||||
self._original_module_input_names, None, *inputs, **kwargs)
|
||||
self._original_module_parameters, None, *inputs, **kwargs)
|
||||
output_names, output_dynamic_axes, self._original_module_output_schema = \
|
||||
_ortmodule_output_transformation.parse_outputs_for_onnx_export_and_extract_output_schema(
|
||||
self._original_module, inputs, kwargs)
|
||||
|
|
|
|||
|
|
@ -83,6 +83,21 @@ class NeuralNetMultiplePositionalArguments(torch.nn.Module):
|
|||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
class NeuralNetMultiplePositionalArgumentsVarKeyword(torch.nn.Module):
    """Small two-layer network whose ``forward`` declares a ``**kwargs`` parameter.

    Test fixture: the network itself only consumes ``input1`` and ``input2``;
    the ``**kwargs`` parameter exists solely so that the ``forward`` signature
    contains a VAR_KEYWORD parameter.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        """Build the layers: Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, num_classes)."""
        super().__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1, input2, **kwargs):
        """Return ``fc2(relu(fc1(input1 + input2)))``.

        Any extra keyword arguments are accepted and ignored.
        """
        # **kwargs is intentionally unused; it only shapes the signature.
        model_input = input1 + input2
        out = self.fc1(model_input)
        out = self.relu(out)
        out = self.fc2(out)
        return out
|
||||
|
||||
class NeuralNetPositionalArguments(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNetPositionalArguments, self).__init__()
|
||||
|
|
@ -218,18 +233,38 @@ def test_forward_call_multiple_positional_arguments():
|
|||
output = ort_model(x, y)
|
||||
assert output is not None
|
||||
|
||||
# TODO: Re-enable after "Support models with dynamically defined inputs" done.
|
||||
# def test_forward_call_positional_arguments():
|
||||
# device = 'cuda'
|
||||
def test_forward_call_multiple_positional_arguments_var_keyword():
    """Wrapping a module whose ``forward`` declares ``**kwargs`` must raise NotImplementedError."""
    device = 'cuda'

    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetMultiplePositionalArgumentsVarKeyword(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)

    # TODO: remove exception check and uncomment the rest of the test when
    # PyTorch ONNX exporter supports **kwargs.
    with pytest.raises(NotImplementedError) as runtime_error:
        ort_model = ORTModule(model)
    # The error message is expected to name the unsupported construct.
    assert '**kwargs' in str(runtime_error.value)

    # # Check that the original forward signature is preserved.
    # assert signature(model.forward) == signature(ort_model.forward)
    # x = torch.randn(N, D_in, device=device)
    # y = torch.randn(N, D_in, device=device)

    # # Make sure model runs without any exception
    # output = ort_model(x, y)
    # assert output is not None
|
||||
|
||||
def test_forward_call_positional_arguments():
    """An ORTModule-wrapped NeuralNetPositionalArguments runs when called with three positional tensors."""
    device = 'cuda'

    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
    model = ORTModule(model)
    # Three independent random inputs of identical shape, unpacked positionally below.
    args = [torch.randn(N, D_in, device=device) for _ in range(3)]

    # Make sure model runs without any exception
    output = model(*args)
    assert output is not None
|
||||
|
||||
def test_forward_call_keyword_arguments():
|
||||
device = 'cuda'
|
||||
|
|
|
|||
Loading…
Reference in a new issue