onnxruntime/orttraining/orttraining/python/training/__init__.py

19 lines
745 B
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy
from onnxruntime.capi.training.training_session import TrainingSession
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, optim, model_desc_validation
# ORTModule is an optional export of this package: re-export it when
# `.ortmodule` can be imported, otherwise leave the name undefined here.
try:
    from .ortmodule import ORTModule
except ImportError:
    # That is OK iff this is not an ORTModule training package.
    # NOTE(review): this handler also silences an ImportError raised from
    # inside `.ortmodule` itself (for example by one of its own imports) --
    # confirm that hiding such failures is intended.
    pass