Deprecate ORTTrainer (#13022)

This commit is contained in:
Baiju Meswani 2022-09-23 18:10:09 -07:00 committed by GitHub
parent 6f27659ceb
commit bcc93ab17c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View file

@ -731,6 +731,10 @@ class ORTTrainer:
optimized_model_filepath: path to output the optimized training graph.
Defaults to "" (no output).
"""
warnings.warn(
"ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.",
FutureWarning,
)
warnings.warn(
"DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it"
)

View file

@ -1,19 +1,20 @@
import copy
import io
import os
import onnx
import torch
from inspect import signature
import warnings
from functools import partial
from inspect import signature
import numpy as np
import onnx
import torch
import onnxruntime as ort
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions, _checkpoint_storage
from .model_desc_validation import _ORTTrainerModelDesc
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from . import ORTTrainerOptions, _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess
from .model_desc_validation import _ORTTrainerModelDesc
class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
@ -119,6 +120,11 @@ class ORTTrainer(object):
"""
def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
warnings.warn(
"ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.",
FutureWarning,
)
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
assert isinstance(
@ -291,10 +297,8 @@ class ORTTrainer(object):
f.write(self._onnx_model.SerializeToString())
def _check_model_export(self, input):
from onnx import helper, TensorProto, numpy_helper
import numpy as np
from numpy.testing import assert_allclose
import _test_helpers
from onnx import TensorProto, helper, numpy_helper
onnx_model_copy = copy.deepcopy(self._onnx_model)