mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Deprecate ORTTrainer (#13022)
This commit is contained in:
parent
6f27659ceb
commit
bcc93ab17c
2 changed files with 17 additions and 9 deletions
|
|
@ -731,6 +731,10 @@ class ORTTrainer:
|
|||
optimized_model_filepath: path to output the optimized training graph.
|
||||
Defaults to "" (no output).
|
||||
"""
|
||||
warnings.warn(
|
||||
"ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
warnings.warn(
|
||||
"DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
import copy
|
||||
import io
|
||||
import os
|
||||
import onnx
|
||||
import torch
|
||||
from inspect import signature
|
||||
import warnings
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
|
||||
import onnxruntime as ort
|
||||
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions, _checkpoint_storage
|
||||
from .model_desc_validation import _ORTTrainerModelDesc
|
||||
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
from . import ORTTrainerOptions, _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess
|
||||
from .model_desc_validation import _ORTTrainerModelDesc
|
||||
|
||||
|
||||
class TrainStepInfo(object):
|
||||
r"""Private class used to store runtime information from current train step.
|
||||
|
|
@ -119,6 +120,11 @@ class ORTTrainer(object):
|
|||
"""
|
||||
|
||||
def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
|
||||
warnings.warn(
|
||||
"ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
|
||||
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
|
||||
assert isinstance(
|
||||
|
|
@ -291,10 +297,8 @@ class ORTTrainer(object):
|
|||
f.write(self._onnx_model.SerializeToString())
|
||||
|
||||
def _check_model_export(self, input):
|
||||
from onnx import helper, TensorProto, numpy_helper
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
import _test_helpers
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
|
||||
onnx_model_copy = copy.deepcopy(self._onnx_model)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue