diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index d24357c288..78d668bc14 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -46,7 +46,13 @@ def _deepcopy_model_input(*inputs, **kwargs): for model_input in inputs: sample_inputs_copy.append(model_input.data if isinstance(model_input, torch.Tensor) else model_input) sample_inputs_copy = copy.deepcopy(tuple(sample_inputs_copy)) - return sample_inputs_copy + + sample_kwargs_copy = {} + for name, model_input in kwargs.items(): + sample_kwargs_copy[name] = model_input.data if isinstance(model_input, torch.Tensor) else model_input + sample_kwargs_copy = copy.deepcopy(sample_kwargs_copy) + + return sample_inputs_copy, sample_kwargs_copy def _onnx_value_info_to_buffer_tensor(value_info, device): '''Create a torch zeroed tensor with the same shape and type of `value_info`''' @@ -55,26 +61,38 @@ def _onnx_value_info_to_buffer_tensor(value_info, device): dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type) return torch.zeros(shape, device=device, dtype=dtype) -def _parse_inputs_for_onnx_export(module, *inputs, **kwargs): - # Ignore optional *inputs explicitly specified as None - sig = signature(module.forward) - all_input_names = sig.parameters.keys() +def _parse_inputs_for_onnx_export(all_input_names, onnx_graph, *inputs, **kwargs): + # Ignore optional inputs explicitly specified as None + # ONNX exporter may remove unused inputs + onnx_graph_input_names = [] + if onnx_graph is not None: + onnx_graph_input_names = set([inp.name for inp in onnx_graph.graph.input]) + input_names = [] dynamic_axes = {} input_names_require_grad = [] + input_shape = [] + for input_idx, name in enumerate(all_input_names): + inp = None if input_idx < len(inputs) and inputs[input_idx] is not None: - if inputs[input_idx].requires_grad: + inp = inputs[input_idx] + elif name in kwargs and kwargs[name] is not None: + inp = kwargs[name] + if inp is not None and (onnx_graph is None or name in onnx_graph_input_names): + if inp.requires_grad: # input_names_require_grad holds all input tensors that have requires_grad input_names_require_grad.append(name) input_names.append(name) dynamic_axes[name] = {} - for dim_idx in range(len(inputs[input_idx].shape)): - dynamic_axes[name].update({dim_idx : 'input{}_dim{}'.format(input_idx, dim_idx)}) - return input_names, dynamic_axes, input_names_require_grad + for dim_idx in range(len(inp.shape)): + dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'}) -def _parse_outputs_for_onnx_export(module, inputs): + input_shape.append(list(inp.size())) + return input_names, dynamic_axes, input_names_require_grad, input_shape + +def _parse_outputs_for_onnx_export(module, inputs, kwargs): def _create_output_dim_names_from_mapping(output): output_names, dynamic_axes = [], {} @@ -106,7 +124,7 @@ def _parse_outputs_for_onnx_export(module, inputs): sample_output_type = None with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = _deepcopy_model_input(*inputs) + sample_inputs_copy, sample_kwargs_copy = _deepcopy_model_input(*inputs, **kwargs) try: # Deepcopy model, in case model is stateful and changes after model run. model_copy = copy.deepcopy(module) @@ -115,7 +133,7 @@ def _parse_outputs_for_onnx_export(module, inputs): warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." " Compute will continue, but unexpected results may occur!") - sample_outputs = model_copy(*sample_inputs_copy) + sample_outputs = model_copy(*sample_inputs_copy, **sample_kwargs_copy) sample_output_type = type(sample_outputs) if isinstance(sample_outputs, torch.Tensor): output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False) @@ -189,6 +207,8 @@ class ORTModule(torch.nn.Module): # User module is wrapped to use its initializers and save computed gradients self._original_module = module + sig = signature(self._original_module.forward) + self._original_module_input_names = sig.parameters.keys() self._onnx_inference = None self._is_training = True @@ -354,14 +374,13 @@ class ORTModule(torch.nn.Module): raise RuntimeError('A device must be specified in the model or data!') self._get_inference_graph_and_init_gradient_graph_builder(*inputs, **kwargs) - _, _, input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs) + _, _, input_names_require_grad, new_input_shape = _parse_inputs_for_onnx_export(self._original_module_input_names, self._onnx_inference, *inputs, **kwargs) # If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad if input_names_require_grad != self._input_names_require_grad: self._input_names_require_grad = input_names_require_grad self._initialize_module_gradient_graph_builder() - new_input_shape = [list(input.size()) for input in inputs if input is not None] if self._current_input_shape is None or self._current_input_shape != new_input_shape: self._current_input_shape = new_input_shape self._build_training_graph() @@ -379,7 +398,9 @@ class ORTModule(torch.nn.Module): def forward(ctx, *inputs, **kwargs): '''Performs forward pass based on user input and PyTorch initializer - TODO: **kwargs are not supported + Autograd Function's apply() doesn't support keyword arguments, + so `*inputs` has all the arguments - keyword arguments converted + to positional by the caller. Module outputs are returned to the user ''' @@ -426,27 +447,29 @@ class ORTModule(torch.nn.Module): for backward_output in backward_outputs[num_user_input_grads:]] return tuple(results) - proc_inputs = [data for data in inputs if data is not None] - return _populate_user_output(self._original_module_output_type, self._onnx_graphs_info.user_output_names, - _ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs))) + _ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*inputs, **kwargs))) @_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__) def _convert_training_graph_input_to_list(self, *inputs, **kwargs): '''Creates forward `*inputs` list from user input and PyTorch initializers - TODO: **kwargs is not supported TODO: How IO binding model inputs and outputs affects initializer copies? ONNX Runtime forward requires an order list of: * User input: computed from forward InferenceSession * Initializers: computed from original PyTorch model parameters - - This codes assumes the exported model's inputs and initializers - are the same as the original PyTorch model ''' # User inputs - result = list(inputs[:len(self._onnx_graphs_info.user_input_names)]) + result = [] + for input_idx, name in enumerate(self._original_module_input_names): + inp = None + if input_idx < len(inputs) and inputs[input_idx] is not None: + inp = inputs[input_idx] + elif name in kwargs and kwargs[name] is not None: + inp = kwargs[name] + if inp is not None and name in self._onnx_graphs_info.user_input_names: + result.append(inp) # Initializers for param in self._original_module.named_parameters(): @@ -462,8 +485,8 @@ class ORTModule(torch.nn.Module): ''' # Setup dynamic axes for onnx model - input_names, dynamic_axes, self._input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs) - output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs) + input_names, dynamic_axes, self._input_names_require_grad, _ = _parse_inputs_for_onnx_export(self._original_module_input_names, None, *inputs, **kwargs) + output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs, kwargs) dynamic_axes.update(output_dynamic_axes) # Export torch.nn.Module to ONNX @@ -472,12 +495,12 @@ class ORTModule(torch.nn.Module): # Deepcopy inputs, since input values may change after model run. # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy = _deepcopy_model_input(*inputs, **kwargs) + sample_inputs_copy, sample_kwargs_copy = _deepcopy_model_input(*inputs, **kwargs) try: with torch.no_grad(): torch.onnx.export(self._original_module, - sample_inputs_copy, + sample_inputs_copy + (sample_kwargs_copy, ), f, input_names=input_names, output_names=output_names, diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index dcc610cb46..36e98f9716 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -47,8 +47,8 @@ def run_ortmodule_api_tests(cwd, log): # because ORTModule doesn't support multiple run call at the same time for test_name in plugin.collected: run_subprocess([ - sys.executable, '-m', 'pytest', - 'orttraining_test_ortmodule_api.py', '-sv', '-k', test_name], cwd=cwd).check_returncode() + sys.executable, '-m', 'pytest', '-sv', + 'orttraining_test_ortmodule_api.py' + '::' + test_name], cwd=cwd).check_returncode() def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir): log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda)) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 5d786e6aa8..a31cc06627 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -91,6 +91,17 @@ class NeuralNetPositionalAndKeywordArguments(torch.nn.Module): out = self.fc2(out) return out +class NeuralNetSimplePositionalAndKeywordArguments(torch.nn.Module): + def __init__(self): + super(NeuralNetSimplePositionalAndKeywordArguments, self).__init__() + self.a = torch.nn.Parameter(torch.FloatTensor([-1., 1.])) + def forward(self, x, y=None, z=None): + if z is not None: + return torch.mean(self.a) + x + 4 * z + if y is not None: + return torch.mean(self.a) + 3 * y + return torch.mean(self.a) + x + def _get_bert_for_sequence_classification_model(device): """Returns the BertForSequenceClassification pretrained model""" @@ -185,6 +196,40 @@ def test_forward_call_positional_and_keyword_arguments(): output = model(a, x, y, z) assert output is not None +@pytest.mark.parametrize("forward_statement", [ + "model(one)", + "model(x=one)", + "model(one, None, None)", + "model(one, None, z=None)", + "model(one, None)", + "model(x=one, y=one)", + "model(y=one, x=one)", + "model(y=one, z=None, x=one)", + "model(one, None, z=one)", + "model(x=one, z=one)", + "model(one, z=one)", + "model(one, z=one, y=one)", + "model(one, one, one)", + "model(one, None, one)", + "model(z=one, x=one, y=one)", + "model(z=one, x=one, y=None)" +]) +def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_statement): + one = torch.FloatTensor([1]) + + model = NeuralNetSimplePositionalAndKeywordArguments() + pytorch_result = eval(forward_statement + ".item()") + + model = NeuralNetSimplePositionalAndKeywordArguments() + model = ORTModule(model) + ortmodule_result = eval(forward_statement) + # TODO: remove backward call when the issue with multiple call to forward fixed. + ortmodule_result.backward() + ortmodule_result = ortmodule_result.item() + ortmodule_result_again = eval(forward_statement + ".item()") + assert ortmodule_result == ortmodule_result_again + assert pytorch_result == ortmodule_result + def test_model_cuda(): original_device = 'cpu' to_device = 'cuda' diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 016580014d..f56385d371 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -72,18 +72,9 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification - # TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule - # outputs = model(b_input_ids, - # token_type_ids = None, - # attention_mask = b_input_mask, - # labels = b_labels) outputs = model(b_input_ids, - b_input_mask, - None, - None, - None, - None, - b_labels) + attention_mask=b_input_mask, + labels=b_labels) # The call to `model` always returns a tuple, so we need to pull the # loss value out of the tuple. @@ -170,17 +161,12 @@ def test(model, validation_dataloader, device, args): # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification - # TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule # TODO: original sample had the last argument equal to None, but b_labels is because model was # exported using 3 inputs for training, so validation must follow. # Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint outputs = model(b_input_ids, - b_input_mask, - None, - None, - None, - None, - b_labels) + attention_mask=b_input_mask, + labels=b_labels) # Get the "logits" output by the model. The "logits" are the output # values prior to applying an activation function like the softmax.