mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix ONNX exporter call with latest API for ORTrainer (#9228)
* update the exporter call with latest api in orttrainer * use official export api instead of the private call
This commit is contained in:
parent
448325b254
commit
be4d887439
1 changed files with 1 additions and 2 deletions
|
|
@ -540,12 +540,11 @@ class ORTTrainer(object):
|
|||
pytorch_export_contrib_ops.unregister()
|
||||
|
||||
# Export torch.nn.Module to ONNX
|
||||
torch.onnx._export(model, tuple(sample_inputs_copy), f,
|
||||
torch.onnx.export(model, tuple(sample_inputs_copy), f,
|
||||
input_names=[input.name for input in self.model_desc.inputs],
|
||||
output_names=[output.name for output in self.model_desc.outputs],
|
||||
opset_version=self.options._internal_use.onnx_opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
_retain_param_name=True,
|
||||
example_outputs=tuple(sample_outputs),
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
|
|
|||
Loading…
Reference in a new issue