From a9d0d3323e60bd8e4d89eb608e7195b27504d2be Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Tue, 12 Jul 2022 17:26:06 -0700 Subject: [PATCH] Use updated symbolic_helper.check_training_mode (#11900) Co-authored-by: Jingyan Wang, Baiju Meswani --- .../_custom_autograd_function_exporter.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 63af43ce48..1459d3b86d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -4,14 +4,17 @@ # -------------------------------------------------------------------------- import sys +import warnings + import torch import torch.utils.checkpoint -import warnings +from packaging import version from torch.onnx import symbolic_helper from onnxruntime.capi._pybind_state import register_torch_autograd_function -from ._fallback import _FallbackManager, ORTModuleONNXModelException, ORTModuleTorchModelException, wrap_exception + from . import _logger +from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException, _FallbackManager, wrap_exception # Some autograd.Function's shouldn't be exported as PythonOp. # If CheckpointFunction is exported as PythonOp, the checkpointed computation @@ -37,7 +40,15 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) inplace = kwargs["inplace"] - training_mode = symbolic_helper._training_mode + # TODO move to public API once exporter team exposes that + training_mode = None + runtime_pytorch_version = version.parse(torch.__version__.split("+")[0]) + if runtime_pytorch_version > version.parse("1.11"): + from torch.onnx import _globals + + training_mode = _globals.GLOBALS.training_mode + else: + training_mode = symbolic_helper._training_mode cconv = n.cconv() input_tensor_types = [] input_requires_grads = []