From f38f2d5b5495acdf8fd21c25976d67df47a3f1bf Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Tue, 1 Sep 2020 19:00:49 -0700 Subject: [PATCH] Port #4920 into the new pytorch frontend (#4965) --- .../python/experimental/orttrainer.py | 95 ++++++++----------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/orttraining/orttraining/python/experimental/orttrainer.py b/orttraining/orttraining/python/experimental/orttrainer.py index a64eb371a3..e25e7aec69 100644 --- a/orttraining/orttraining/python/experimental/orttrainer.py +++ b/orttraining/orttraining/python/experimental/orttrainer.py @@ -375,13 +375,41 @@ class ORTTrainer(object): results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] return results[0] if len (results) == 1 else results - def _combine_torch_model_with_loss_fn_and_wrap_input(self): + def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): + # Dynamic axes + dynamic_axes = {} + for input in self.model_desc.inputs: + symbolic_axis = {} + for i, axis in enumerate(input.shape): + if isinstance(axis, str): + symbolic_axis[i] = axis + if len(symbolic_axis): + dynamic_axes[input.name] = symbolic_axis + for output in self.model_desc.outputs: + symbolic_axis = {} + for i, axis in enumerate(output.shape): + if isinstance(axis, str): + symbolic_axis[i] = axis + if len(symbolic_axis): + dynamic_axes[output.name] = symbolic_axis + + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + if isinstance(inputs, dict): + sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] + elif isinstance(inputs, (list, tuple)): + sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)] + else: + raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") + + # PyTorch ONNX exporter does not match argument names + # This is an issue because the ONNX graph depends on all inputs to be specified + # Validate loss_fn if self.loss_fn: sig_loss = signature(self.loss_fn) if len(sig_loss.parameters) != 2: - raise RuntimeError( - "loss function should take two arguments - predict and label.") + raise RuntimeError("loss function should take two arguments - predict and label.") # Basic input names from model input_names = [input.name for input in self.model_desc.inputs] @@ -416,39 +444,7 @@ class ORTTrainer(object): preds = model_out return self.loss_fn(preds, label), preds - return CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) - - - def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): - # Dynamic axes - dynamic_axes = {} - for input in self.model_desc.inputs: - symbolic_axis = {} - for i, axis in enumerate(input.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name] = symbolic_axis - for output in self.model_desc.outputs: - symbolic_axis = {} - for i, axis in enumerate(output.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name] = symbolic_axis - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)] - else: - raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") - - # PyTorch ONNX exporter does not match argument names - # This is an issue because the ONNX graph depends on all inputs to be specified - model = self._combine_torch_model_with_loss_fn_and_wrap_input() + model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) # Do an inference to grab output types model.eval() @@ -493,23 +489,16 @@ class ORTTrainer(object): onnx_model = onnx.load_model_from_string(f.getvalue()) # Remove 'model.' prefix introduced by CombineTorchModelLossFn class - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith('model.'): - replace_name_dict[n.name] = n.name[len('model.'):] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - # ONNX model initializers may contain non-trainable registered buffers - # that are not part of PyTorch model named parameteres - named_parameters = model.model.named_parameters() if hasattr(model, 'model') else model.named_parameters() - assert set([n for n, t in named_parameters]).issubset( - set([n.name for n in onnx_model.graph.initializer])), \ - "Initializer names do not match between PyTorch model and ONNX model, " \ - "please report a bug to ONNX Runtime." + if isinstance(model, CombineTorchModelLossFnWrapInput): + replace_name_dict = {} + for n in onnx_model.graph.initializer: + if n.name.startswith('model.'): + replace_name_dict[n.name] = n.name[len('model.'):] + n.name = replace_name_dict[n.name] + for n in onnx_model.graph.node: + for i, name in enumerate(n.input): + if name in replace_name_dict: + n.input[i] = replace_name_dict[name] return onnx_model