onnxruntime/orttraining/orttraining/python/training/__init__.py

19 lines
754 B
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy, TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession
# Options need to be imported before `ORTTrainer`.
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, model_desc_validation, optim
try:
from .ortmodule import ORTModule
except ImportError:
# That is OK iff this is not a ORTModule training package
pass