diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index b2f6cad907..1b5d200a1d 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -540,12 +540,11 @@ class ORTTrainer(object): pytorch_export_contrib_ops.unregister() # Export torch.nn.Module to ONNX - torch.onnx._export(model, tuple(sample_inputs_copy), f, + torch.onnx.export(model, tuple(sample_inputs_copy), f, input_names=[input.name for input in self.model_desc.inputs], output_names=[output.name for output in self.model_desc.outputs], opset_version=self.options._internal_use.onnx_opset_version, dynamic_axes=dynamic_axes, - _retain_param_name=True, example_outputs=tuple(sample_outputs), do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING)