Fix opset version change by not using copy of global constant (#9393)

This commit is contained in:
Thiago Crepaldi 2021-10-27 12:42:06 -04:00 committed by GitHub
parent b5a652c578
commit 5d5c03bcdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 60 additions and 37 deletions

View file

@ -18,6 +18,8 @@ from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_instal
################################################################################
# All global constant goes here, before ORTModule is imported ##################
# NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
# assign them new values. Importing them directly do not propagate changes.
################################################################################
ONNX_OPSET_VERSION = 12
MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1'

View file

@ -19,7 +19,7 @@ from ._fallback import (_FallbackManager,
ORTModuleTorchModelException,
wrap_exception)
from ._gradient_accumulation_manager import GradientAccumulationManager
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
from onnxruntime.training import ortmodule
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
@ -99,7 +99,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# indicators of some logic have been executed previously thus could be skipped for faster training
# default is enabled, if not define in os env
self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT)
self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT)
if os.getenv('ORTMODULE_SKIPCHECK_POLICY') is not None:
self._skip_check = reduce(lambda x, y: x | y,
[_SkipCheck[name] for name in
@ -363,7 +363,7 @@ class GraphExecutionManager(GraphExecutionInterface):
_logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level):
required_export_kwargs = {'input_names': self._input_info.names,
'output_names': output_names,
'opset_version': ONNX_OPSET_VERSION,
'opset_version': ortmodule.ONNX_OPSET_VERSION,
'do_constant_folding': False,
'training': self._export_mode,
'dynamic_axes': self._input_info.dynamic_axes,

View file

@ -4,7 +4,7 @@
import tempfile
import torch
from ... import ORTModule
from ... import ONNX_OPSET_VERSION
from .... import ortmodule
from ...debug_options import DebugOptions
@ -91,7 +91,7 @@ class HierarchicalORTModule(torch.nn.Module):
try:
with tempfile.NamedTemporaryFile(prefix='sub-module') as temp:
torch.onnx.export(
module, args, temp, opset_version=ONNX_OPSET_VERSION,
module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION,
do_constant_folding=False, export_params=False,
keep_initializers_as_inputs=True,
training=torch.onnx.TrainingMode.TRAINING)
@ -125,7 +125,7 @@ class HierarchicalORTModule(torch.nn.Module):
try:
with tempfile.NamedTemporaryFile(prefix='sub-module') as temp:
torch.onnx.export(
module, args, temp, opset_version=ONNX_OPSET_VERSION,
module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION,
do_constant_folding=False, export_params=False,
keep_initializers_as_inputs=True,
training=torch.onnx.TrainingMode.TRAINING)

View file

@ -13,9 +13,7 @@ from .debug_options import DebugOptions
from ._fallback import (_FallbackManager,
_FallbackPolicy,
ORTModuleFallbackException)
from . import (_FALLBACK_INIT_EXCEPTION,
ORTMODULE_FALLBACK_POLICY,
ORTMODULE_FALLBACK_RETRY)
from onnxruntime.training import ortmodule
from onnxruntime.tools import pytorch_export_contrib_ops
@ -60,13 +58,13 @@ class ORTModule(torch.nn.Module):
# Fallback settings
self._fallback_manager = _FallbackManager(pytorch_module=module,
policy=ORTMODULE_FALLBACK_POLICY,
retry=ORTMODULE_FALLBACK_RETRY)
policy=ortmodule.ORTMODULE_FALLBACK_POLICY,
retry=ortmodule.ORTMODULE_FALLBACK_RETRY)
try:
# Read ORTModule module initialization status
if _FALLBACK_INIT_EXCEPTION:
raise _FALLBACK_INIT_EXCEPTION
if ortmodule._FALLBACK_INIT_EXCEPTION:
raise ortmodule._FALLBACK_INIT_EXCEPTION
super(ORTModule, self).__init__()

View file

@ -3,9 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.training.ortmodule import (ORTMODULE_TORCH_CPP_DIR,
ONNXRUNTIME_CUDA_VERSION,
ONNXRUNTIME_ROCM_VERSION)
from onnxruntime.training import ortmodule
from glob import glob
from shutil import copyfile
@ -24,11 +22,11 @@ def _list_extensions(path):
def _list_cpu_extensions():
return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cpu'))
return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cpu'))
def _list_cuda_extensions():
return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cuda'))
return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cuda'))
def _install_extension(ext_name, ext_path, cwd):
@ -44,12 +42,15 @@ def build_torch_cpp_extensions():
'''Builds PyTorch CPP extensions and returns metadata'''
# Run this from within onnxruntime package folder
is_gpu_available = ONNXRUNTIME_CUDA_VERSION is not None or ONNXRUNTIME_ROCM_VERSION is not None
os.chdir(ORTMODULE_TORCH_CPP_DIR)
is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or\
ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR)
# Extensions might leverage CUDA/ROCM versions internally
os.environ["ONNXRUNTIME_CUDA_VERSION"] = ONNXRUNTIME_CUDA_VERSION if not ONNXRUNTIME_CUDA_VERSION is None else ''
os.environ["ONNXRUNTIME_ROCM_VERSION"] = ONNXRUNTIME_ROCM_VERSION if not ONNXRUNTIME_ROCM_VERSION is None else ''
os.environ["ONNXRUNTIME_CUDA_VERSION"] = ortmodule.ONNXRUNTIME_CUDA_VERSION \
if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else ''
os.environ["ONNXRUNTIME_ROCM_VERSION"] = ortmodule.ONNXRUNTIME_ROCM_VERSION \
if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else ''
############################################################################
# Pytorch CPP Extensions that DO require CUDA/ROCM
@ -57,32 +58,32 @@ def build_torch_cpp_extensions():
if is_gpu_available:
for ext_setup in _list_cuda_extensions():
_install_extension(ext_setup.split(
os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR)
os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR)
############################################################################
# Pytorch CPP Extensions that DO NOT require CUDA/ROCM
############################################################################
for ext_setup in _list_cpu_extensions():
_install_extension(ext_setup.split(
os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR)
os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR)
############################################################################
# Install Pytorch CPP Extensions into local onnxruntime package folder
############################################################################
torch_cpp_exts = glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
torch_cpp_exts = glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
'build',
'lib.*',
'*.so'))
torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
'build',
'lib.*',
'*.dll')))
torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR,
torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR,
'build',
'lib.*',
'*.dylib')))
for ext in torch_cpp_exts:
dest_ext = os.path.join(ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext))
dest_ext = os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext))
print(f'Installing {ext} -> {dest_ext}')
copyfile(ext, dest_ext)

View file

@ -924,7 +924,6 @@ def test_export_correctness_pool2d(pool_type, stride):
super(NeuralNetPool2d, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.pool_type = pool_type
def forward(self, input):
x = self.conv(input)
@ -1053,8 +1052,8 @@ def test_gradient_correctness_reducesum(dim, keepdim):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
# Since multinomial is a generator function, we do not have to test for gradient
# Two consecutive calls on the torch.multinomail on a probability distribution with more
# than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in
# Two consecutive calls on the torch.multinomail on a probability distribution with more
# than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in
# the same output. Thus we reset the seed before each call to the op torch.multinomial.
@pytest.mark.parametrize("input_shape", ([5], [2,5]))
@pytest.mark.parametrize("num_samples, replacement", ((1, False), (2, True)))
@ -1083,8 +1082,8 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)
# run the ort prediction again since the first call involves export
# and run step, which means the torch.multinomial is called twice in a row without
# run the ort prediction again since the first call involves export
# and run step, which means the torch.multinomial is called twice in a row without
# resetting the generator in between, which will result in a different output
ort_prediction = run_step(ort_model, ort_input)
@ -1843,7 +1842,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
ort_model(x, y, z)
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
def test_dynamic_axes_config():
@ -2683,7 +2682,7 @@ def test_forward_dynamic_args():
assert output is not None
hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema))
assert hash_args_size3 != hash_args_size2
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
@ -4214,7 +4213,6 @@ def test_sigmoid_grad():
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
def test_tanh_grad():
class NeuralNetTanh(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
@ -4246,4 +4244,28 @@ def test_tanh_grad():
pt_prediction, pt_loss = run_step(pt_model, pt_x)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
@pytest.mark.parametrize("opset_version", [12, 13])
def test_opset_version_change(opset_version):
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
x = torch.randn(N, D_in, device=device)
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
# Must import a namespace containing ONNX_OPSET_VERSION, not ONNX_OPSET_VERSION directly
from onnxruntime.training import ortmodule
ortmodule.ONNX_OPSET_VERSION=opset_version
# Make sure model runs without any exception
prediction = ort_model(x)
assert prediction is not None
prediction = prediction.sum()
prediction.backward()
# Check opset version on ONNX model
exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
assert exported_model.opset_import[0].version == opset_version