diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 452fb8a09a..cb41b8b86d 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -19,11 +19,11 @@ endif() file(GLOB onnxruntime_pybind_srcs CONFIGURE_DEPENDS ${onnxruntime_pybind_srcs_pattern} ) - + if(NOT onnxruntime_PYBIND_EXPORT_OPSCHEMA) list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_schema.cc) endif() - + if(onnxruntime_ENABLE_TRAINING) list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc) endif() @@ -38,11 +38,11 @@ if (onnxruntime_ENABLE_EAGER_MODE) ) if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - list(APPEND onnxruntime_eager_extension_srcs + list(APPEND onnxruntime_eager_extension_srcs "${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc") endif() - list(APPEND onnxruntime_pybind_srcs + list(APPEND onnxruntime_pybind_srcs ${onnxruntime_eager_extension_srcs}) endif() @@ -286,9 +286,6 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*" ) - file(GLOB onnxruntime_python_train_tools_srcs CONFIGURE_DEPENDS - "${REPO_ROOT}/tools/python/register_custom_ops_pytorch_exporter.py" - ) else() file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/training/*.py" @@ -557,9 +554,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_train_tools_srcs} - $/onnxruntime/training/ ) endif() diff --git a/tools/python/register_custom_ops_pytorch_exporter.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py similarity index 60% 
rename from tools/python/register_custom_ops_pytorch_exporter.py rename to onnxruntime/python/tools/pytorch_export_contrib_ops.py index 04ecca8693..d217c1086f 100644 --- a/tools/python/register_custom_ops_pytorch_exporter.py +++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py @@ -1,21 +1,38 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -# Register pytorch symbolic for export using ONNX Runtime contrib ops -from torch.onnx import register_custom_op_symbolic +""" +Support for registering ONNX Runtime's built-in contrib ops with +PyTorch-ONNX exporter (torch.onnx.export). +""" + +import typing + +try: + from torch.onnx import register_custom_op_symbolic +except ModuleNotFoundError: + raise ModuleNotFoundError( + "This module is only useful in combination with PyTorch. " + "To install PyTorch see https://pytorch.org/.") import torch.onnx.symbolic_helper as sym_help +import torch.onnx.symbolic_registry as sym_registry -_onnx_opset_version = 1 +_OPSET_VERSION = 1 +_registered_ops: typing.Set[str] = set() -def register_custom_op(): - """ - This function registers symbolic functions for - custom ops that are implemented as part of ONNX Runtime +def _reg(symbolic_fn: typing.Callable): + name = "::%s" % symbolic_fn.__name__ + register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION) + _registered_ops.add(name) + + +def register(): + """Register ONNX Runtime's built-in contrib ops. + + Should be run before torch.onnx.export(). 
""" - # Symbolic definition - def grid_sample(g, input, grid, mode, padding_mode, align_corners): + def grid_sampler(g, input, grid, mode, padding_mode, align_corners): # mode # 'bilinear' : onnx::Constant[value={0}] @@ -42,46 +59,33 @@ def register_custom_op(): mode_s=mode_str, padding_mode_s=padding_mode_str, align_corners_i=align_corners) + _reg(grid_sampler) def inverse(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type()) + _reg(inverse) def gelu(g, self): return g.op("com.microsoft::Gelu", self).setType(self.type()) + _reg(gelu) def triu(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type()) + _reg(triu) def tril(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type()) - - # Op Registration - register_custom_op_symbolic('::grid_sampler', grid_sample, _onnx_opset_version) - register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version) - register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version) - register_custom_op_symbolic('::triu', triu, _onnx_opset_version) - register_custom_op_symbolic('::tril', tril, _onnx_opset_version) + _reg(tril) -def unregister_custom_op(): - """ - This function unregisters symbolic functions for - custom ops that are implemented as part of ONNX Runtime - """ - - import torch.onnx.symbolic_registry as sym_registry +def unregister(): + """Unregister ONNX Runtime's built-in contrib ops.""" # TODO: replace this once PyTorch supports unregister natively. 
- def unregister(name, opset_version): + # https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343 + for name in _registered_ops: ns, kind = name.split("::") - from torch.onnx.symbolic_helper import _onnx_stable_opsets - - for version in _onnx_stable_opsets: - if version >= opset_version and sym_registry.is_registered_op(kind, ns, version): + for version in sym_help._onnx_stable_opsets: + if (version >= _OPSET_VERSION and + sym_registry.is_registered_op(kind, ns, version)): del sym_registry._registry[(ns, version)][kind] - - unregister('::grid_sampler', _onnx_opset_version) - unregister('::inverse', _onnx_opset_version) - unregister('::gelu', _onnx_opset_version) - unregister('::triu', _onnx_opset_version) - unregister('::tril', _onnx_opset_version) diff --git a/tools/test/test_custom_ops_pytorch_exporter.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py similarity index 79% rename from tools/test/test_custom_ops_pytorch_exporter.py rename to onnxruntime/test/python/test_pytorch_export_contrib_ops.py index 035dc6c59a..b115cb23f0 100644 --- a/tools/test/test_custom_ops_pytorch_exporter.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -1,15 +1,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-# -# Test export of pytorch operators using ONNX Runtime contrib ops + +"""Test export of PyTorch operators using ONNX Runtime contrib ops.""" import torch import onnxruntime +from onnxruntime.tools import pytorch_export_contrib_ops import numpy as np import unittest import io import copy -from python.register_custom_ops_pytorch_exporter import register_custom_op def ort_test_with_input(ort_sess, input, output, rtol, atol): @@ -35,9 +35,8 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol): [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] -# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch. -# To register custom ops and run the tests, you should set PYTHONPATH as: -# PYTHONPATH= python -m pytest -v test_custom_ops_pytorch_exporter.py +# This set of tests verifies ONNX model export and compares outputs between +# PyTorch and ORT. class ONNXExporterTest(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version opset_version = _export_onnx_opset_version @@ -45,7 +44,7 @@ class ONNXExporterTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - register_custom_op() + pytorch_export_contrib_ops.register() def run_test(self, model, input=None, custom_opsets=None, @@ -103,12 +102,12 @@ class ONNXExporterTest(unittest.TestCase): return torch.inverse(x) + x x = torch.randn(2, 3, 3) - self.run_test(CustomInverse(), x, custom_opsets={'com.microsoft': 1}) + self.run_test(CustomInverse(), x, custom_opsets={"com.microsoft": 1}) def test_gelu(self): model = torch.nn.GELU() x = torch.randn(3, 3) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) def test_triu(self): for i in range(-5, 5): @@ -118,13 +117,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module() x = torch.randn(5, 4, 7, dtype=torch.float32) - self.run_test(model, x, 
custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 4, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): class Module2D(torch.nn.Module): @@ -133,13 +132,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module2D() x = torch.randn(4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) def test_tril(self): for i in range(-5, 5): @@ -149,13 +148,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module() x = torch.randn(5, 4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 4, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): class Module2D(torch.nn.Module): @@ -164,13 +163,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module2D() x = torch.randn(4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, 
custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) # opset 9 tests, with keep_initializers_as_inputs=False for @@ -181,5 +180,5 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"), keep_initializers_as_inputs=False)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 82b65d7813..5d1046ad34 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -355,8 +355,8 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op sample_inputs_copy = copy.deepcopy(sample_inputs) # Enable contrib ops export from PyTorch - from onnxruntime.training import register_custom_ops_pytorch_exporter - register_custom_ops_pytorch_exporter.register_custom_op() + from onnxruntime.tools import pytorch_export_contrib_ops + pytorch_export_contrib_ops.register() torch.onnx._export(model, tuple(sample_inputs_copy), f, input_names=input_names, diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 61d1e951a6..964582c107 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -10,7 +10,7 @@ from ._custom_gradient_registry import CustomGradientRegistry from .debug_options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception from . 
import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY -from onnxruntime.training import register_custom_ops_pytorch_exporter +from onnxruntime.tools import pytorch_export_contrib_ops import functools import torch @@ -71,7 +71,7 @@ class ORTModule(torch.nn.Module): super(ORTModule, self).__init__() # Support contrib OPs - register_custom_ops_pytorch_exporter.register_custom_op() + pytorch_export_contrib_ops.register() CustomOpSymbolicRegistry.register_all() CustomGradientRegistry.register_all() diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 2d371a891a..b2f6cad907 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -120,8 +120,8 @@ class ORTTrainer(object): ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) """ - def __init__(self, model, model_desc, optim_config, - loss_fn=None, + def __init__(self, model, model_desc, optim_config, + loss_fn=None, options=None): assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" @@ -532,13 +532,12 @@ class ORTTrainer(object): sample_inputs_copy = copy.deepcopy(sample_inputs) # Handle contrib OPs support - from onnxruntime.training import register_custom_ops_pytorch_exporter + from onnxruntime.tools import pytorch_export_contrib_ops if self.options._internal_use.enable_onnx_contrib_ops: - # Enable contrib ops export from PyTorch - register_custom_ops_pytorch_exporter.register_custom_op() + pytorch_export_contrib_ops.register() else: - # Unregister contrib ops, if they were registered in previous calls - register_custom_ops_pytorch_exporter.unregister_custom_op() + # Unregister in case they were registered in previous calls. 
+ pytorch_export_contrib_ops.unregister() # Export torch.nn.Module to ONNX torch.onnx._export(model, tuple(sample_inputs_copy), f, @@ -566,9 +565,9 @@ class ORTTrainer(object): return onnx_model - def _create_ort_training_session(self, - optimizer_state_dict={}, - session_options=None, + def _create_ort_training_session(self, + optimizer_state_dict={}, + session_options=None, provider_options=None): # Validating frozen_weights names unused_frozen_weights = [n for n in self.options.utils.frozen_weights\ @@ -622,7 +621,7 @@ class ORTTrainer(object): self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) self.options.distributed.data_parallel_size = self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - + # TrainingParameters ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = loss_name @@ -753,7 +752,7 @@ class ORTTrainer(object): # Create training session used by train_step # pass all optimizer states to the backend self._create_ort_training_session(optimizer_state_dict, - session_options=session_options, + session_options=session_options, provider_options=provider_options) # Update model description to update dtype when mixed precision is enabled @@ -880,8 +879,8 @@ class ORTTrainer(object): # This prevents CPU -> GPU -> CPU copies between frontend and backend target_device = 'cpu' # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. + # in that case, we temporary prefer to leave the input/output on CPU and let ORT session + # to move the data between device and host. # so output will be on the same device as input. 
try: test_pt_device = torch.device(target_device) @@ -889,7 +888,7 @@ class ORTTrainer(object): #in this case, input/output must on CPU assert(input.device.type == 'cpu') target_device = 'cpu' - + torch_tensor = torch.zeros(output_desc.shape, device=target_device, dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype) iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(target_device), @@ -1282,7 +1281,7 @@ class ORTTrainer(object): if self._training_session: current_state_dict = self.state_dict() if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow + # for Zero enabled, the current trainer might not have the complete state, and we must allow # extra keys to be present in the state dict allow_unexpected = True if self.options.distributed.deepspeed_zero_optimization.stage > 0 else False _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) @@ -1360,7 +1359,7 @@ class ORTTrainer(object): def _aggregation_required(self, loaded_trainer_options): """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - # To load states in the backend, aggregation is required for every ZeRO + # To load states in the backend, aggregation is required for every ZeRO # or Megatron checkpoint return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \ loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9f6902222c..d96418cc7e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1436,6 +1436,9 @@ def run_training_python_frontend_tests(cwd): run_subprocess([ sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd) + # Not technically training related, but it needs torch to be installed. 
+ run_subprocess([ + sys.executable, '-m', 'pytest', '-sv', 'test_pytorch_export_contrib_ops.py'], cwd=cwd) def run_training_python_frontend_e2e_tests(cwd):