Use updated symbolic_helper.check_training_mode (#11900)

Co-authored-by: Jingyan Wang, Baiju Meswani
This commit is contained in:
jingyanwangms 2022-07-12 17:26:06 -07:00 committed by GitHub
parent 178a413ca1
commit a9d0d3323e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -4,14 +4,17 @@
# --------------------------------------------------------------------------
import sys
import warnings
import torch
import torch.utils.checkpoint
import warnings
from packaging import version
from torch.onnx import symbolic_helper
from onnxruntime.capi._pybind_state import register_torch_autograd_function
from ._fallback import _FallbackManager, ORTModuleONNXModelException, ORTModuleTorchModelException, wrap_exception
from . import _logger
from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException, _FallbackManager, wrap_exception
# Some autograd.Function's shouldn't be exported as PythonOp.
# If CheckpointFunction is exported as PythonOp, the checkpointed computation
@ -37,7 +40,15 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"wrap exportable sub-nn.Module's as ORTModule."
)
inplace = kwargs["inplace"]
training_mode = symbolic_helper._training_mode
# TODO move to public API once exporter team exposes that
training_mode = None
runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
if runtime_pytorch_version > version.parse("1.11"):
from torch.onnx import _globals
training_mode = _globals.GLOBALS.training_mode
else:
training_mode = symbolic_helper._training_mode
cconv = n.cconv()
input_tensor_types = []
input_requires_grads = []