From bcc93ab17c3151ff34abbc068946eaf4fb889434 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 23 Sep 2022 18:10:09 -0700 Subject: [PATCH] Deprecate ORTTrainer (#13022) --- orttraining/orttraining/python/ort_trainer.py | 4 ++++ .../orttraining/python/training/orttrainer.py | 22 +++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index ec159d83bd..a9229bd3c3 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -731,6 +731,10 @@ class ORTTrainer: optimized_model_filepath: path to output the optimized training graph. Defaults to "" (no output). """ + warnings.warn( + "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", + FutureWarning, + ) warnings.warn( "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" ) diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index ed84582546..bdf6a1e9e1 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -1,19 +1,20 @@ import copy import io import os -import onnx -import torch -from inspect import signature import warnings from functools import partial +from inspect import signature + import numpy as np +import onnx +import torch import onnxruntime as ort -from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions, _checkpoint_storage -from .model_desc_validation import _ORTTrainerModelDesc - from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from . import ORTTrainerOptions, _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess +from .model_desc_validation import _ORTTrainerModelDesc + class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. @@ -119,6 +120,11 @@ class ORTTrainer(object): """ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): + warnings.warn( + "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", + FutureWarning, + ) + assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" assert isinstance( @@ -291,10 +297,8 @@ class ORTTrainer(object): f.write(self._onnx_model.SerializeToString()) def _check_model_export(self, input): - from onnx import helper, TensorProto, numpy_helper - import numpy as np from numpy.testing import assert_allclose - import _test_helpers + from onnx import TensorProto, helper, numpy_helper onnx_model_copy = copy.deepcopy(self._onnx_model)