mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Use updated symbolic_helper.check_training_mode (#11900)
Co-authored-by: Jingyan Wang, Baiju Meswani
This commit is contained in:
parent
178a413ca1
commit
a9d0d3323e
1 changed files with 14 additions and 3 deletions
|
|
@ -4,14 +4,17 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import warnings
|
||||
from packaging import version
|
||||
from torch.onnx import symbolic_helper
|
||||
|
||||
from onnxruntime.capi._pybind_state import register_torch_autograd_function
|
||||
from ._fallback import _FallbackManager, ORTModuleONNXModelException, ORTModuleTorchModelException, wrap_exception
|
||||
|
||||
from . import _logger
|
||||
from ._fallback import ORTModuleONNXModelException, ORTModuleTorchModelException, _FallbackManager, wrap_exception
|
||||
|
||||
# Some autograd.Function's shouldn't be exported as PythonOp.
|
||||
# If CheckpointFunction is exported as PythonOp, the checkpointed computation
|
||||
|
|
@ -37,7 +40,15 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
"wrap exportable sub-nn.Module's as ORTModule."
|
||||
)
|
||||
inplace = kwargs["inplace"]
|
||||
training_mode = symbolic_helper._training_mode
|
||||
# TODO move to public API once exporter team exposes that
|
||||
training_mode = None
|
||||
runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
|
||||
if runtime_pytorch_version > version.parse("1.11"):
|
||||
from torch.onnx import _globals
|
||||
|
||||
training_mode = _globals.GLOBALS.training_mode
|
||||
else:
|
||||
training_mode = symbolic_helper._training_mode
|
||||
cconv = n.cconv()
|
||||
input_tensor_types = []
|
||||
input_requires_grads = []
|
||||
|
|
|
|||
Loading…
Reference in a new issue