mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Fix opset version change by not using copy of global constant (#9393)
This commit is contained in:
parent
b5a652c578
commit
5d5c03bcdc
6 changed files with 60 additions and 37 deletions
|
|
@ -18,6 +18,8 @@ from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_instal
|
|||
|
||||
################################################################################
|
||||
# All global constants go here, before ORTModule is imported ###################
|
||||
# NOTE: To *change* values at runtime, import onnxruntime.training.ortmodule and
|
||||
# assign them new values. Importing them directly does not propagate changes.
|
||||
################################################################################
|
||||
ONNX_OPSET_VERSION = 12
|
||||
MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1'
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from ._fallback import (_FallbackManager,
|
|||
ORTModuleTorchModelException,
|
||||
wrap_exception)
|
||||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
|
@ -99,7 +99,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
# indicators of some logic have been executed previously thus could be skipped for faster training
|
||||
# default is enabled, if not define in os env
|
||||
self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT)
|
||||
self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT)
|
||||
if os.getenv('ORTMODULE_SKIPCHECK_POLICY') is not None:
|
||||
self._skip_check = reduce(lambda x, y: x | y,
|
||||
[_SkipCheck[name] for name in
|
||||
|
|
@ -363,7 +363,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
_logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level):
|
||||
required_export_kwargs = {'input_names': self._input_info.names,
|
||||
'output_names': output_names,
|
||||
'opset_version': ONNX_OPSET_VERSION,
|
||||
'opset_version': ortmodule.ONNX_OPSET_VERSION,
|
||||
'do_constant_folding': False,
|
||||
'training': self._export_mode,
|
||||
'dynamic_axes': self._input_info.dynamic_axes,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import tempfile
|
||||
import torch
|
||||
from ... import ORTModule
|
||||
from ... import ONNX_OPSET_VERSION
|
||||
from .... import ortmodule
|
||||
from ...debug_options import DebugOptions
|
||||
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix='sub-module') as temp:
|
||||
torch.onnx.export(
|
||||
module, args, temp, opset_version=ONNX_OPSET_VERSION,
|
||||
module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False, export_params=False,
|
||||
keep_initializers_as_inputs=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
|
@ -125,7 +125,7 @@ class HierarchicalORTModule(torch.nn.Module):
|
|||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix='sub-module') as temp:
|
||||
torch.onnx.export(
|
||||
module, args, temp, opset_version=ONNX_OPSET_VERSION,
|
||||
module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False, export_params=False,
|
||||
keep_initializers_as_inputs=True,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
|
|
|
|||
|
|
@ -13,9 +13,7 @@ from .debug_options import DebugOptions
|
|||
from ._fallback import (_FallbackManager,
|
||||
_FallbackPolicy,
|
||||
ORTModuleFallbackException)
|
||||
from . import (_FALLBACK_INIT_EXCEPTION,
|
||||
ORTMODULE_FALLBACK_POLICY,
|
||||
ORTMODULE_FALLBACK_RETRY)
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
|
||||
|
|
@ -60,13 +58,13 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Fallback settings
|
||||
self._fallback_manager = _FallbackManager(pytorch_module=module,
|
||||
policy=ORTMODULE_FALLBACK_POLICY,
|
||||
retry=ORTMODULE_FALLBACK_RETRY)
|
||||
policy=ortmodule.ORTMODULE_FALLBACK_POLICY,
|
||||
retry=ortmodule.ORTMODULE_FALLBACK_RETRY)
|
||||
|
||||
try:
|
||||
# Read ORTModule module initialization status
|
||||
if _FALLBACK_INIT_EXCEPTION:
|
||||
raise _FALLBACK_INIT_EXCEPTION
|
||||
if ortmodule._FALLBACK_INIT_EXCEPTION:
|
||||
raise ortmodule._FALLBACK_INIT_EXCEPTION
|
||||
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from onnxruntime.training.ortmodule import (ORTMODULE_TORCH_CPP_DIR,
|
||||
ONNXRUNTIME_CUDA_VERSION,
|
||||
ONNXRUNTIME_ROCM_VERSION)
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from glob import glob
|
||||
from shutil import copyfile
|
||||
|
|
@ -24,11 +22,11 @@ def _list_extensions(path):
|
|||
|
||||
|
||||
def _list_cpu_extensions():
|
||||
return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cpu'))
|
||||
return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cpu'))
|
||||
|
||||
|
||||
def _list_cuda_extensions():
|
||||
return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cuda'))
|
||||
return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cuda'))
|
||||
|
||||
|
||||
def _install_extension(ext_name, ext_path, cwd):
|
||||
|
|
@ -44,12 +42,15 @@ def build_torch_cpp_extensions():
|
|||
'''Builds PyTorch CPP extensions and returns metadata'''
|
||||
|
||||
# Run this from within onnxruntime package folder
|
||||
is_gpu_available = ONNXRUNTIME_CUDA_VERSION is not None or ONNXRUNTIME_ROCM_VERSION is not None
|
||||
os.chdir(ORTMODULE_TORCH_CPP_DIR)
|
||||
is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or\
|
||||
ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
|
||||
os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR)
|
||||
|
||||
# Extensions might leverage CUDA/ROCM versions internally
|
||||
os.environ["ONNXRUNTIME_CUDA_VERSION"] = ONNXRUNTIME_CUDA_VERSION if not ONNXRUNTIME_CUDA_VERSION is None else ''
|
||||
os.environ["ONNXRUNTIME_ROCM_VERSION"] = ONNXRUNTIME_ROCM_VERSION if not ONNXRUNTIME_ROCM_VERSION is None else ''
|
||||
os.environ["ONNXRUNTIME_CUDA_VERSION"] = ortmodule.ONNXRUNTIME_CUDA_VERSION \
|
||||
if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else ''
|
||||
os.environ["ONNXRUNTIME_ROCM_VERSION"] = ortmodule.ONNXRUNTIME_ROCM_VERSION \
|
||||
if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else ''
|
||||
|
||||
############################################################################
|
||||
# Pytorch CPP Extensions that DO require CUDA/ROCM
|
||||
|
|
@ -57,32 +58,32 @@ def build_torch_cpp_extensions():
|
|||
if is_gpu_available:
|
||||
for ext_setup in _list_cuda_extensions():
|
||||
_install_extension(ext_setup.split(
|
||||
os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR)
|
||||
os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR)
|
||||
|
||||
############################################################################
|
||||
# Pytorch CPP Extensions that DO NOT require CUDA/ROCM
|
||||
############################################################################
|
||||
for ext_setup in _list_cpu_extensions():
|
||||
_install_extension(ext_setup.split(
|
||||
os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR)
|
||||
os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR)
|
||||
|
||||
############################################################################
|
||||
# Install Pytorch CPP Extensions into local onnxruntime package folder
|
||||
############################################################################
|
||||
torch_cpp_exts = glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
|
||||
torch_cpp_exts = glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
|
||||
'build',
|
||||
'lib.*',
|
||||
'*.so'))
|
||||
torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
|
||||
torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
|
||||
'build',
|
||||
'lib.*',
|
||||
'*.dll')))
|
||||
torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
|
||||
torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
|
||||
'build',
|
||||
'lib.*',
|
||||
'*.dylib')))
|
||||
for ext in torch_cpp_exts:
|
||||
dest_ext = os.path.join(ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext))
|
||||
dest_ext = os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext))
|
||||
print(f'Installing {ext} -> {dest_ext}')
|
||||
copyfile(ext, dest_ext)
|
||||
|
||||
|
|
|
|||
|
|
@ -924,7 +924,6 @@ def test_export_correctness_pool2d(pool_type, stride):
|
|||
super(NeuralNetPool2d, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.pool_type = pool_type
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
x = self.conv(input)
|
||||
|
|
@ -1053,8 +1052,8 @@ def test_gradient_correctness_reducesum(dim, keepdim):
|
|||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
# Since multinomial is a generator function, we do not have to test for gradient
|
||||
# Two consecutive calls to torch.multinomial on a probability distribution with more
|
||||
# than one index with non-zero probability (e.g., [0, 10, 3, 0]) will not result in
|
||||
# Two consecutive calls to torch.multinomial on a probability distribution with more
|
||||
# than one index with non-zero probability (e.g., [0, 10, 3, 0]) will not result in
|
||||
# the same output. Thus we reset the seed before each call to the op torch.multinomial.
|
||||
@pytest.mark.parametrize("input_shape", ([5], [2,5]))
|
||||
@pytest.mark.parametrize("num_samples, replacement", ((1, False), (2, True)))
|
||||
|
|
@ -1083,8 +1082,8 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
|
|||
ort_input = copy.deepcopy(pt_input)
|
||||
pt_prediction = run_step(pt_model, pt_input)
|
||||
ort_prediction = run_step(ort_model, ort_input)
|
||||
# run the ort prediction again since the first call involves export
|
||||
# and run step, which means the torch.multinomial is called twice in a row without
|
||||
# run the ort prediction again since the first call involves export
|
||||
# and run step, which means the torch.multinomial is called twice in a row without
|
||||
# resetting the generator in between, which will result in a different output
|
||||
ort_prediction = run_step(ort_model, ort_input)
|
||||
|
||||
|
|
@ -1843,7 +1842,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
|
|||
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
|
||||
ort_model(x, y, z)
|
||||
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
|
||||
|
||||
|
||||
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
|
||||
|
||||
def test_dynamic_axes_config():
|
||||
|
|
@ -2683,7 +2682,7 @@ def test_forward_dynamic_args():
|
|||
assert output is not None
|
||||
hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
|
||||
assert hash_args_size3 != hash_args_size2
|
||||
|
||||
|
||||
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
|
||||
|
||||
|
||||
|
|
@ -4214,7 +4213,6 @@ def test_sigmoid_grad():
|
|||
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
|
||||
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
|
||||
|
||||
|
||||
def test_tanh_grad():
|
||||
class NeuralNetTanh(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
|
|
@ -4246,4 +4244,28 @@ def test_tanh_grad():
|
|||
pt_prediction, pt_loss = run_step(pt_model, pt_x)
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
|
||||
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
|
||||
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13])
def test_opset_version_change(opset_version):
    """Check that reassigning ``ortmodule.ONNX_OPSET_VERSION`` changes the exported opset.

    The constant is reassigned *after* the ``ORTModule`` wrapper is built, so
    the final assertion only holds if the value is looked up on the
    ``ortmodule`` namespace when the model is exported, rather than captured
    earlier through a ``from ... import ONNX_OPSET_VERSION`` copy.

    The previous value of the constant is restored when the test finishes
    (pass or fail) so that the override does not leak into other tests.
    """
    # Must import a namespace containing ONNX_OPSET_VERSION, not ONNX_OPSET_VERSION directly
    from onnxruntime.training import ortmodule

    device = 'cuda'

    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

    ort_model = ORTModule(model)

    # Remember the current value: this is process-wide state shared by every
    # test in the session.
    original_opset_version = ortmodule.ONNX_OPSET_VERSION
    ortmodule.ONNX_OPSET_VERSION = opset_version
    try:
        # Make sure model runs without any exception
        prediction = ort_model(x)
        assert prediction is not None
        prediction = prediction.sum()
        prediction.backward()

        # Check opset version on ONNX model
        exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
        assert exported_model.opset_import[0].version == opset_version
    finally:
        # Undo the override even when an assertion above fails.
        ortmodule.ONNX_OPSET_VERSION = original_opset_version
|
||||
|
|
|
|||
Loading…
Reference in a new issue