Fix ONNX exporter call with latest API for ORTrainer (#9228)

* update the exporter call with latest api in orttrainer

* use official export api instead of the private call
This commit is contained in:
Tang, Cheng 2021-10-01 10:49:55 -07:00 committed by GitHub
parent 448325b254
commit be4d887439
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -540,12 +540,11 @@ class ORTTrainer(object):
pytorch_export_contrib_ops.unregister()
# Export torch.nn.Module to ONNX
torch.onnx._export(model, tuple(sample_inputs_copy), f,
torch.onnx.export(model, tuple(sample_inputs_copy), f,
input_names=[input.name for input in self.model_desc.inputs],
output_names=[output.name for output in self.model_desc.outputs],
opset_version=self.options._internal_use.onnx_opset_version,
dynamic_axes=dynamic_axes,
_retain_param_name=True,
example_outputs=tuple(sample_outputs),
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING)