Support keyword arguments for ORTModule (#6539)

* Support keyword arguments for ORTModule.

* Add backward workaround to the test.

* Specify test name directly without -k.

* Handle unused inputs removed by ONNX exporter.
This commit is contained in:
Sergii Dymchenko 2021-02-19 13:40:44 -08:00 committed by GitHub
parent 1a2f1bd23a
commit 58f3aca95d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 47 deletions

View file

@ -46,7 +46,13 @@ def _deepcopy_model_input(*inputs, **kwargs):
for model_input in inputs:
sample_inputs_copy.append(model_input.data if isinstance(model_input, torch.Tensor) else model_input)
sample_inputs_copy = copy.deepcopy(tuple(sample_inputs_copy))
return sample_inputs_copy
sample_kwargs_copy = {}
for name, model_input in kwargs.items():
sample_kwargs_copy[name] = model_input.data if isinstance(model_input, torch.Tensor) else model_input
sample_kwargs_copy = copy.deepcopy(sample_kwargs_copy)
return sample_inputs_copy, sample_kwargs_copy
def _onnx_value_info_to_buffer_tensor(value_info, device):
'''Create a torch zeroed tensor with the same shape and type of `value_info`'''
@ -55,26 +61,38 @@ def _onnx_value_info_to_buffer_tensor(value_info, device):
dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type)
return torch.zeros(shape, device=device, dtype=dtype)
def _parse_inputs_for_onnx_export(module, *inputs, **kwargs):
# Ignore optional *inputs explicitly specified as None
sig = signature(module.forward)
all_input_names = sig.parameters.keys()
def _parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs):
# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
onnx_graph_input_names = []
if onnx_graph is not None:
onnx_graph_input_names = set([inp.name for inp in onnx_graph.graph.input])
input_names = []
dynamic_axes = {}
input_names_require_grad = []
input_shape = []
for input_idx, name in enumerate(all_input_names):
inp = None
if input_idx < len(inputs) and inputs[input_idx] is not None:
if inputs[input_idx].requires_grad:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
if inp is not None and (onnx_graph is None or name in onnx_graph_input_names):
if inp.requires_grad:
# input_names_require_grad holds all input tensors that have requires_grad
input_names_require_grad.append(name)
input_names.append(name)
dynamic_axes[name] = {}
for dim_idx in range(len(inputs[input_idx].shape)):
dynamic_axes[name].update({dim_idx : 'input{}_dim{}'.format(input_idx, dim_idx)})
return input_names, dynamic_axes, input_names_require_grad
for dim_idx in range(len(inp.shape)):
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
def _parse_outputs_for_onnx_export(module, inputs):
input_shape.append(list(inp.size()))
return input_names, dynamic_axes, input_names_require_grad, input_shape
def _parse_outputs_for_onnx_export(module, inputs, kwargs):
def _create_output_dim_names_from_mapping(output):
output_names, dynamic_axes = [], {}
@ -106,7 +124,7 @@ def _parse_outputs_for_onnx_export(module, inputs):
sample_output_type = None
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = _deepcopy_model_input(*inputs)
sample_inputs_copy, sample_kwargs_copy = _deepcopy_model_input(*inputs, **kwargs)
try:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(module)
@ -115,7 +133,7 @@ def _parse_outputs_for_onnx_export(module, inputs):
warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
" Compute will continue, but unexpected results may occur!")
sample_outputs = model_copy(*sample_inputs_copy)
sample_outputs = model_copy(*sample_inputs_copy, **sample_kwargs_copy)
sample_output_type = type(sample_outputs)
if isinstance(sample_outputs, torch.Tensor):
output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False)
@ -189,6 +207,8 @@ class ORTModule(torch.nn.Module):
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
sig = signature(self._original_module.forward)
self._original_module_input_names = sig.parameters.keys()
self._onnx_inference = None
self._is_training = True
@ -354,14 +374,13 @@ class ORTModule(torch.nn.Module):
raise RuntimeError('A device must be specified in the model or data!')
self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
_, _, input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
_, _, input_names_require_grad, new_input_shape = _parse_inputs_for_onnx_export(self._original_module_input_names, self._onnx_inference, *inputs, **kwargs)
# If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
if input_names_require_grad != self._input_names_require_grad:
self._input_names_require_grad = input_names_require_grad
self._initialize_module_gradient_graph_builder()
new_input_shape = [list(input.size()) for input in inputs if input is not None]
if self._current_input_shape is None or self._current_input_shape != new_input_shape:
self._current_input_shape = new_input_shape
self._build_training_graph()
@ -379,7 +398,9 @@ class ORTModule(torch.nn.Module):
def forward(ctx, *inputs, **kwargs):
'''Performs forward pass based on user input and PyTorch initializer
TODO: **kwargs are not supported
Autograd Function's apply() doesn't support keyword arguments,
so `*inputs` has all the arguments - keyword arguments converted
to positional by the caller.
Module outputs are returned to the user
'''
@ -426,27 +447,29 @@ class ORTModule(torch.nn.Module):
for backward_output in backward_outputs[num_user_input_grads:]]
return tuple(results)
proc_inputs = [data for data in inputs if data is not None]
return _populate_user_output(self._original_module_output_type, self._onnx_graphs_info.user_output_names,
_ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs)))
_ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*inputs, **kwargs)))
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
def _convert_training_graph_input_to_list(self, *inputs, **kwargs):
'''Creates forward `*inputs` list from user input and PyTorch initializers
TODO: **kwargs is not supported
TODO: How IO binding model inputs and outputs affects initializer copies?
ONNX Runtime forward requires an order list of:
* User input: computed from forward InferenceSession
* Initializers: computed from original PyTorch model parameters
This codes assumes the exported model's inputs and initializers
are the same as the original PyTorch model
'''
# User inputs
result = list(inputs[:len(self._onnx_graphs_info.user_input_names)])
result = []
for input_idx, name in enumerate(self._original_module_input_names):
inp = None
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
if inp is not None and name in self._onnx_graphs_info.user_input_names:
result.append(inp)
# Initializers
for param in self._original_module.named_parameters():
@ -462,8 +485,8 @@ class ORTModule(torch.nn.Module):
'''
# Setup dynamic axes for onnx model
input_names, dynamic_axes, self._input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs)
input_names, dynamic_axes, self._input_names_require_grad, _ = _parse_inputs_for_onnx_export(self._original_module_input_names, None, *inputs, **kwargs)
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs, kwargs)
dynamic_axes.update(output_dynamic_axes)
# Export torch.nn.Module to ONNX
@ -472,12 +495,12 @@ class ORTModule(torch.nn.Module):
# Deepcopy inputs, since input values may change after model run.
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
# Therefore, deepcopy only the data component of the input tensors for export.
sample_inputs_copy = _deepcopy_model_input(*inputs, **kwargs)
sample_inputs_copy, sample_kwargs_copy = _deepcopy_model_input(*inputs, **kwargs)
try:
with torch.no_grad():
torch.onnx.export(self._original_module,
sample_inputs_copy,
sample_inputs_copy + (sample_kwargs_copy, ),
f,
input_names=input_names,
output_names=output_names,

View file

@ -47,8 +47,8 @@ def run_ortmodule_api_tests(cwd, log):
# because ORTModule doesn't support multiple run call at the same time
for test_name in plugin.collected:
run_subprocess([
sys.executable, '-m', 'pytest',
'orttraining_test_ortmodule_api.py', '-sv', '-k', test_name], cwd=cwd).check_returncode()
sys.executable, '-m', 'pytest', '-sv',
'orttraining_test_ortmodule_api.py' + '::' + test_name], cwd=cwd).check_returncode()
def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda))

View file

@ -91,6 +91,17 @@ class NeuralNetPositionalAndKeywordArguments(torch.nn.Module):
out = self.fc2(out)
return out
class NeuralNetSimplePositionalAndKeywordArguments(torch.nn.Module):
def __init__(self):
super(NeuralNetSimplePositionalAndKeywordArguments, self).__init__()
self.a = torch.nn.Parameter(torch.FloatTensor([-1., 1.]))
def forward(self, x, y=None, z=None):
if z is not None:
return torch.mean(self.a) + x + 4 * z
if y is not None:
return torch.mean(self.a) + 3 * y
return torch.mean(self.a) + x
def _get_bert_for_sequence_classification_model(device):
"""Returns the BertForSequenceClassification pretrained model"""
@ -185,6 +196,40 @@ def test_forward_call_positional_and_keyword_arguments():
output = model(a, x, y, z)
assert output is not None
@pytest.mark.parametrize("forward_statement", [
"model(one)",
"model(x=one)",
"model(one, None, None)",
"model(one, None, z=None)",
"model(one, None)",
"model(x=one, y=one)",
"model(y=one, x=one)",
"model(y=one, z=None, x=one)",
"model(one, None, z=one)",
"model(x=one, z=one)",
"model(one, z=one)",
"model(one, z=one, y=one)",
"model(one, one, one)",
"model(one, None, one)",
"model(z=one, x=one, y=one)",
"model(z=one, x=one, y=None)"
])
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_statement):
one = torch.FloatTensor([1])
model = NeuralNetSimplePositionalAndKeywordArguments()
pytorch_result = eval(forward_statement + ".item()")
model = NeuralNetSimplePositionalAndKeywordArguments()
model = ORTModule(model)
ortmodule_result = eval(forward_statement)
# TODO: remove backward call when the issue with multiple call to forward fixed.
ortmodule_result.backward()
ortmodule_result = ortmodule_result.item()
ortmodule_result_again = eval(forward_statement + ".item()")
assert ortmodule_result == ortmodule_result_again
assert pytorch_result == ortmodule_result
def test_model_cuda():
original_device = 'cpu'
to_device = 'cuda'

View file

@ -72,18 +72,9 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
# outputs = model(b_input_ids,
# token_type_ids = None,
# attention_mask = b_input_mask,
# labels = b_labels)
outputs = model(b_input_ids,
b_input_mask,
None,
None,
None,
None,
b_labels)
attention_mask=b_input_mask,
labels=b_labels)
# The call to `model` always returns a tuple, so we need to pull the
# loss value out of the tuple.
@ -170,17 +161,12 @@ def test(model, validation_dataloader, device, args):
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
# TODO: original sample had the last argument equal to None, but b_labels is because model was
# exported using 3 inputs for training, so validation must follow.
# Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint
outputs = model(b_input_ids,
b_input_mask,
None,
None,
None,
None,
b_labels)
attention_mask=b_input_mask,
labels=b_labels)
# Get the "logits" output by the model. The "logits" are the output
# values prior to applying an activation function like the softmax.