diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 63af43ce48..1459d3b86d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -4,14 +4,17 @@ # -------------------------------------------------------------------------- import sys +import warnings + import torch import torch.utils.checkpoint -import warnings +from packaging import version from torch.onnx import symbolic_helper from onnxruntime.capi._pybind_state import register_torch_autograd_function -from ._fallback import _FallbackManager, ORTModuleONNXModelException, ORTModuleTorchModelException, wrap_exception + from . import _logger +from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException, _FallbackManager, wrap_exception # Some autograd.Function's shouldn't be exported as PythonOp. # If CheckpointFunction is exported as PythonOp, the checkpointed computation @@ -37,7 +40,15 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) inplace = kwargs["inplace"] - training_mode = symbolic_helper._training_mode + # TODO move to public API once exporter team exposes that + training_mode = None + runtime_pytorch_version = version.parse(torch.__version__.split("+")[0]) + if runtime_pytorch_version > version.parse("1.11"): + from torch.onnx import _globals + + training_mode = _globals.GLOBALS.training_mode + else: + training_mode = symbolic_helper._training_mode cconv = n.cconv() input_tensor_types = [] input_requires_grads = []