diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index b87bee3bb5..258fc68906 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -18,6 +18,8 @@ from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_instal ################################################################################ # All global constant goes here, before ORTModule is imported ################## +# NOTE: To *change* values at runtime, import onnxruntime.training.ortmodule and +# assign them new values. Importing them directly does not propagate changes. ################################################################################ ONNX_OPSET_VERSION = 12 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1' diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index c4b71f19f9..fbd27af64c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,7 +19,7 @@ from ._fallback import (_FallbackManager, ORTModuleTorchModelException, wrap_exception) from ._gradient_accumulation_manager import GradientAccumulationManager -from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION +from onnxruntime.training import ortmodule from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference @@ -99,7 +99,7 @@ class GraphExecutionManager(GraphExecutionInterface): # indicators of some logic have been executed previously thus could be skipped for faster training # default is enabled, if not define in os env - self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | 
_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) + self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT) if os.getenv('ORTMODULE_SKIPCHECK_POLICY') is not None: self._skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in @@ -363,7 +363,7 @@ class GraphExecutionManager(GraphExecutionInterface): _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level): required_export_kwargs = {'input_names': self._input_info.names, 'output_names': output_names, - 'opset_version': ONNX_OPSET_VERSION, + 'opset_version': ortmodule.ONNX_OPSET_VERSION, 'do_constant_folding': False, 'training': self._export_mode, 'dynamic_axes': self._input_info.dynamic_axes, diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index a1af396c02..5c79c68f37 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -4,7 +4,7 @@ import tempfile import torch from ... import ORTModule -from ... import ONNX_OPSET_VERSION +from .... 
import ortmodule from ...debug_options import DebugOptions @@ -91,7 +91,7 @@ class HierarchicalORTModule(torch.nn.Module): try: with tempfile.NamedTemporaryFile(prefix='sub-module') as temp: torch.onnx.export( - module, args, temp, opset_version=ONNX_OPSET_VERSION, + module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION, do_constant_folding=False, export_params=False, keep_initializers_as_inputs=True, training=torch.onnx.TrainingMode.TRAINING) @@ -125,7 +125,7 @@ class HierarchicalORTModule(torch.nn.Module): try: with tempfile.NamedTemporaryFile(prefix='sub-module') as temp: torch.onnx.export( - module, args, temp, opset_version=ONNX_OPSET_VERSION, + module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION, do_constant_folding=False, export_params=False, keep_initializers_as_inputs=True, training=torch.onnx.TrainingMode.TRAINING) diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 1c92ffcaaf..35d0306bd1 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -13,9 +13,7 @@ from .debug_options import DebugOptions from ._fallback import (_FallbackManager, _FallbackPolicy, ORTModuleFallbackException) -from . 
import (_FALLBACK_INIT_EXCEPTION, - ORTMODULE_FALLBACK_POLICY, - ORTMODULE_FALLBACK_RETRY) +from onnxruntime.training import ortmodule from onnxruntime.tools import pytorch_export_contrib_ops @@ -60,13 +58,13 @@ class ORTModule(torch.nn.Module): # Fallback settings self._fallback_manager = _FallbackManager(pytorch_module=module, - policy=ORTMODULE_FALLBACK_POLICY, - retry=ORTMODULE_FALLBACK_RETRY) + policy=ortmodule.ORTMODULE_FALLBACK_POLICY, + retry=ortmodule.ORTMODULE_FALLBACK_RETRY) try: # Read ORTModule module initialization status - if _FALLBACK_INIT_EXCEPTION: - raise _FALLBACK_INIT_EXCEPTION + if ortmodule._FALLBACK_INIT_EXCEPTION: + raise ortmodule._FALLBACK_INIT_EXCEPTION super(ORTModule, self).__init__() diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index 1ff667a777..7511266e7a 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -3,9 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from onnxruntime.training.ortmodule import (ORTMODULE_TORCH_CPP_DIR, - ONNXRUNTIME_CUDA_VERSION, - ONNXRUNTIME_ROCM_VERSION) +from onnxruntime.training import ortmodule from glob import glob from shutil import copyfile @@ -24,11 +22,11 @@ def _list_extensions(path): def _list_cpu_extensions(): - return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cpu')) + return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cpu')) def _list_cuda_extensions(): - return _list_extensions(os.path.join(ORTMODULE_TORCH_CPP_DIR, 'cuda')) + return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cuda')) def _install_extension(ext_name, ext_path, cwd): @@ -44,12 +42,15 @@ def build_torch_cpp_extensions(): '''Builds PyTorch CPP extensions and returns metadata''' # Run this from within onnxruntime package folder - is_gpu_available = ONNXRUNTIME_CUDA_VERSION is not None or ONNXRUNTIME_ROCM_VERSION is not None - os.chdir(ORTMODULE_TORCH_CPP_DIR) + is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or\ + ortmodule.ONNXRUNTIME_ROCM_VERSION is not None + os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR) # Extensions might leverage CUDA/ROCM versions internally - os.environ["ONNXRUNTIME_CUDA_VERSION"] = ONNXRUNTIME_CUDA_VERSION if not ONNXRUNTIME_CUDA_VERSION is None else '' - os.environ["ONNXRUNTIME_ROCM_VERSION"] = ONNXRUNTIME_ROCM_VERSION if not ONNXRUNTIME_ROCM_VERSION is None else '' + os.environ["ONNXRUNTIME_CUDA_VERSION"] = ortmodule.ONNXRUNTIME_CUDA_VERSION \ + if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else '' + os.environ["ONNXRUNTIME_ROCM_VERSION"] = ortmodule.ONNXRUNTIME_ROCM_VERSION \ + if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else '' ############################################################################ # Pytorch CPP Extensions that DO require CUDA/ROCM @@ -57,32 +58,32 @@ def build_torch_cpp_extensions(): if is_gpu_available: for 
ext_setup in _list_cuda_extensions(): _install_extension(ext_setup.split( - os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR) + os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) ############################################################################ # Pytorch CPP Extensions that DO NOT require CUDA/ROCM ############################################################################ for ext_setup in _list_cpu_extensions(): _install_extension(ext_setup.split( - os.sep)[-2], ext_setup, ORTMODULE_TORCH_CPP_DIR) + os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) ############################################################################ # Install Pytorch CPP Extensions into local onnxruntime package folder ############################################################################ - torch_cpp_exts = glob(os.path.join(ORTMODULE_TORCH_CPP_DIR, + torch_cpp_exts = glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'build', 'lib.*', '*.so')) - torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR, + torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'build', 'lib.*', '*.dll'))) - torch_cpp_exts.extend(glob(os.path.join(ORTMODULE_TORCH_CPP_DIR, + torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'build', 'lib.*', '*.dylib'))) for ext in torch_cpp_exts: - dest_ext = os.path.join(ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext)) + dest_ext = os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext)) print(f'Installing {ext} -> {dest_ext}') copyfile(ext, dest_ext) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 5739cc0bd9..6e4545f17f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -924,7 +924,6 @@ def test_export_correctness_pool2d(pool_type, stride): 
super(NeuralNetPool2d, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.pool_type = pool_type - def forward(self, input): x = self.conv(input) @@ -1053,8 +1052,8 @@ def test_gradient_correctness_reducesum(dim, keepdim): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) # Since multinomial is a generator function, we do not have to test for gradient -# Two consecutive calls on the torch.multinomail on a probability distribution with more -# than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in +# Two consecutive calls on the torch.multinomial on a probability distribution with more +# than one index with non-zero probability (e.g., [0, 10, 3, 0]) will not result in # the same output. Thus we reset the seed before each call to the op torch.multinomial. @pytest.mark.parametrize("input_shape", ([5], [2,5])) @pytest.mark.parametrize("num_samples, replacement", ((1, False), (2, True))) @@ -1083,8 +1082,8 @@ def test_aten_multinomial(input_shape, num_samples, replacement): ort_input = copy.deepcopy(pt_input) pt_prediction = run_step(pt_model, pt_input) ort_prediction = run_step(ort_model, ort_input) - # run the ort prediction again since the first call involves export - # and run step, which means the torch.multinomial is called twice in a row without + # run the ort prediction again since the first call involves export + # and run step, which means the torch.multinomial is called twice in a row without # resetting the generator in between, which will result in a different output ort_prediction = run_step(ort_model, ort_input) @@ -1843,7 +1842,7 @@ def test_exception_raised_for_custom_class_return_value_module(device): with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) assert 'ORTModule does not support the following model output type' in str(runtime_error.value) - + del os.environ['ORTMODULE_SKIPCHECK_POLICY'] def test_dynamic_axes_config(): 
@@ -2683,7 +2682,7 @@ def test_forward_dynamic_args(): assert output is not None hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_args_size3 != hash_args_size2 - + del os.environ['ORTMODULE_SKIPCHECK_POLICY'] @@ -4214,7 +4213,6 @@ def test_sigmoid_grad(): _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad) _test_helpers.assert_values_are_close(ort_loss, pt_loss) - def test_tanh_grad(): class NeuralNetTanh(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -4246,4 +4244,30 @@ def test_tanh_grad(): pt_prediction, pt_loss = run_step(pt_model, pt_x) _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad) - _test_helpers.assert_values_are_close(ort_loss, pt_loss) \ No newline at end of file + _test_helpers.assert_values_are_close(ort_loss, pt_loss) + +@pytest.mark.parametrize("opset_version", [12, 13]) +def test_opset_version_change(opset_version, monkeypatch): + '''Check that an ONNX_OPSET_VERSION changed at runtime is the opset of the exported ONNX model''' + device = 'cuda' + + N, D_in, H, D_out = 64, 784, 500, 10 + x = torch.randn(N, D_in, device=device) + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + + ort_model = ORTModule(model) + + # Must patch the namespace containing ONNX_OPSET_VERSION, not a directly imported ONNX_OPSET_VERSION + # monkeypatch restores the previous value at teardown, so the override cannot leak into other tests + from onnxruntime.training import ortmodule + monkeypatch.setattr(ortmodule, 'ONNX_OPSET_VERSION', opset_version) + + # Make sure model runs without any exception + prediction = ort_model(x) + assert prediction is not None + prediction = prediction.sum() + prediction.backward() + + # Check opset version on ONNX model + exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + assert exported_model.opset_import[0].version == opset_version