From 47435311f43cf587b40da20dc33d94bbf950e1d4 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 2 Sep 2021 14:26:58 -0700 Subject: [PATCH] Include pytorch_export_contrib_ops in inference builds (#8878) * Include pytorch_export_contrib_ops in inference builds Rename / move it from tools/python/register_custom_ops_pytorch_exporter to onnxruntime/python/tools/pytorch_export_contrib_ops. Rationale for inclusion in inference builds: This code is potentially useful for anyone using ORT, not just training. Rationale for new name: "Contrib op" is the nomenclature used within ORT to refer to the set of ops that are not in the standard op set but are included by default with ORT. This is more specific than "custom op", which is what the PyTorch exporter uses to refer to any non-standard op. Step 1 of addressing #8818. After this is merged I will update the docs. * Enable test_pytorch_export_contrib_ops.py in CI Fixes AB#1342330 --- cmake/onnxruntime_python.cmake | 14 ++-- .../tools/pytorch_export_contrib_ops.py | 72 ++++++++++--------- .../python/test_pytorch_export_contrib_ops.py | 43 ++++++----- orttraining/orttraining/python/ort_trainer.py | 4 +- .../python/training/ortmodule/ortmodule.py | 4 +- .../orttraining/python/training/orttrainer.py | 33 +++++---- tools/ci_build/build.py | 3 + 7 files changed, 86 insertions(+), 87 deletions(-) rename tools/python/register_custom_ops_pytorch_exporter.py => onnxruntime/python/tools/pytorch_export_contrib_ops.py (60%) rename tools/test/test_custom_ops_pytorch_exporter.py => onnxruntime/test/python/test_pytorch_export_contrib_ops.py (79%) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 452fb8a09a..cb41b8b86d 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -19,11 +19,11 @@ endif() file(GLOB onnxruntime_pybind_srcs CONFIGURE_DEPENDS ${onnxruntime_pybind_srcs_pattern} ) - + if(NOT onnxruntime_PYBIND_EXPORT_OPSCHEMA) list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_schema.cc) endif() - + if(onnxruntime_ENABLE_TRAINING) list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc) endif() @@ -38,11 +38,11 @@ if (onnxruntime_ENABLE_EAGER_MODE) ) if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - list(APPEND onnxruntime_eager_extension_srcs + list(APPEND onnxruntime_eager_extension_srcs "${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc") endif() - list(APPEND onnxruntime_pybind_srcs + list(APPEND onnxruntime_pybind_srcs ${onnxruntime_eager_extension_srcs}) endif() @@ -286,9 +286,6 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*" ) - file(GLOB onnxruntime_python_train_tools_srcs CONFIGURE_DEPENDS - "${REPO_ROOT}/tools/python/register_custom_ops_pytorch_exporter.py" - ) else() file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/training/*.py" @@ -557,9 +554,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_train_tools_srcs} - $/onnxruntime/training/ ) endif() diff --git a/tools/python/register_custom_ops_pytorch_exporter.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py similarity index 60% rename from tools/python/register_custom_ops_pytorch_exporter.py rename to onnxruntime/python/tools/pytorch_export_contrib_ops.py index 04ecca8693..d217c1086f 100644 --- a/tools/python/register_custom_ops_pytorch_exporter.py +++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py @@ -1,21 +1,38 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -# Register pytorch symbolic for export using ONNX Runtime contrib ops -from torch.onnx import register_custom_op_symbolic +""" +Support for registering ONNX Runtime's built-in contrib ops with +PyTorch-ONNX exporter (torch.onnx.export). +""" + +import typing + +try: + from torch.onnx import register_custom_op_symbolic +except ModuleNotFoundError: + raise ModuleNotFoundError( + "This module is only useful in combination with PyTorch. " + "To install PyTorch see https://pytorch.org/.") import torch.onnx.symbolic_helper as sym_help +import torch.onnx.symbolic_registry as sym_registry -_onnx_opset_version = 1 +_OPSET_VERSION = 1 +_registered_ops: typing.AbstractSet[str] = set() -def register_custom_op(): - """ - This function registers symbolic functions for - custom ops that are implemented as part of ONNX Runtime +def _reg(symbolic_fn: typing.Callable): + name = "::%s" % symbolic_fn.__name__ + register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION) + _registered_ops.add(name) + + +def register(): + """Register ONNX Runtime's built-in contrib ops. + + Should be run before torch.onnx.export(). """ - # Symbolic definition def grid_sample(g, input, grid, mode, padding_mode, align_corners): # mode # 'bilinear' : onnx::Constant[value={0}] @@ -42,46 +59,33 @@ def register_custom_op(): mode_s=mode_str, padding_mode_s=padding_mode_str, align_corners_i=align_corners) + _reg(grid_sample) def inverse(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type()) + _reg(inverse) def gelu(g, self): return g.op("com.microsoft::Gelu", self).setType(self.type()) + _reg(gelu) def triu(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type()) + _reg(triu) def tril(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type()) - - # Op Registration - register_custom_op_symbolic('::grid_sampler', grid_sample, _onnx_opset_version) - register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version) - register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version) - register_custom_op_symbolic('::triu', triu, _onnx_opset_version) - register_custom_op_symbolic('::tril', tril, _onnx_opset_version) + _reg(tril) -def unregister_custom_op(): - """ - This function unregisters symbolic functions for - custom ops that are implemented as part of ONNX Runtime - """ - - import torch.onnx.symbolic_registry as sym_registry +def unregister(): + """Unregister ONNX Runtime's built-in contrib ops.""" # TODO: replace this once PyTorch supports unregister natively. - def unregister(name, opset_version): + # https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343 + for name in _registered_ops: ns, kind = name.split("::") - from torch.onnx.symbolic_helper import _onnx_stable_opsets - - for version in _onnx_stable_opsets: - if version >= opset_version and sym_registry.is_registered_op(kind, ns, version): + for version in sym_help._onnx_stable_opsets: + if (version >= _OPSET_VERSION and + sym_registry.is_registered_op(kind, ns, version)): del sym_registry._registry[(ns, version)][kind] - - unregister('::grid_sampler', _onnx_opset_version) - unregister('::inverse', _onnx_opset_version) - unregister('::gelu', _onnx_opset_version) - unregister('::triu', _onnx_opset_version) - unregister('::tril', _onnx_opset_version) diff --git a/tools/test/test_custom_ops_pytorch_exporter.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py similarity index 79% rename from tools/test/test_custom_ops_pytorch_exporter.py rename to onnxruntime/test/python/test_pytorch_export_contrib_ops.py index 035dc6c59a..b115cb23f0 100644 --- a/tools/test/test_custom_ops_pytorch_exporter.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -1,15 +1,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -# Test export of pytorch operators using ONNX Runtime contrib ops + +"""Test export of PyTorch operators using ONNX Runtime contrib ops.""" import torch import onnxruntime +from onnxruntime.tools import pytorch_export_contrib_ops import numpy as np import unittest import io import copy -from python.register_custom_ops_pytorch_exporter import register_custom_op def ort_test_with_input(ort_sess, input, output, rtol, atol): @@ -35,9 +35,8 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol): [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] -# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch. -# To register custom ops and run the tests, you should set PYTHONPATH as: -# PYTHONPATH= python -m pytest -v test_custom_ops_pytorch_exporter.py +# These set of tests verify ONNX model export and compares outputs between +# PyTorch and ORT. class ONNXExporterTest(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version opset_version = _export_onnx_opset_version @@ -45,7 +44,7 @@ class ONNXExporterTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - register_custom_op() + pytorch_export_contrib_ops.register() def run_test(self, model, input=None, custom_opsets=None, @@ -103,12 +102,12 @@ class ONNXExporterTest(unittest.TestCase): return torch.inverse(x) + x x = torch.randn(2, 3, 3) - self.run_test(CustomInverse(), x, custom_opsets={'com.microsoft': 1}) + self.run_test(CustomInverse(), x, custom_opsets={"com.microsoft": 1}) def test_gelu(self): model = torch.nn.GELU() x = torch.randn(3, 3) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) def test_triu(self): for i in range(-5, 5): @@ -118,13 +117,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module() x = torch.randn(5, 4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 4, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): class Module2D(torch.nn.Module): @@ -133,13 +132,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module2D() x = torch.randn(4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) def test_tril(self): for i in range(-5, 5): @@ -149,13 +148,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module() x = torch.randn(5, 4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 4, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(5, 0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): class Module2D(torch.nn.Module): @@ -164,13 +163,13 @@ class ONNXExporterTest(unittest.TestCase): model = Module2D() x = torch.randn(4, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 7, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) x = torch.randn(0, 0, dtype=torch.float32) - self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + self.run_test(model, x, custom_opsets={"com.microsoft": 1}) # opset 9 tests, with keep_initializers_as_inputs=False for @@ -181,5 +180,5 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"), keep_initializers_as_inputs=False)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 82b65d7813..5d1046ad34 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -355,8 +355,8 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op sample_inputs_copy = copy.deepcopy(sample_inputs) # Enable contrib ops export from PyTorch - from onnxruntime.training import register_custom_ops_pytorch_exporter - register_custom_ops_pytorch_exporter.register_custom_op() + from onnxruntime.tools import pytorch_export_contrib_ops + pytorch_export_contrib_ops.register() torch.onnx._export(model, tuple(sample_inputs_copy), f, input_names=input_names, diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 61d1e951a6..964582c107 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -10,7 +10,7 @@ from ._custom_gradient_registry import CustomGradientRegistry from .debug_options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY -from onnxruntime.training import register_custom_ops_pytorch_exporter +from onnxruntime.tools import pytorch_export_contrib_ops import functools import torch @@ -71,7 +71,7 @@ class ORTModule(torch.nn.Module): super(ORTModule, self).__init__() # Support contrib OPs - register_custom_ops_pytorch_exporter.register_custom_op() + pytorch_export_contrib_ops.register() CustomOpSymbolicRegistry.register_all() CustomGradientRegistry.register_all() diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 2d371a891a..b2f6cad907 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -120,8 +120,8 @@ class ORTTrainer(object): ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) """ - def __init__(self, model, model_desc, optim_config, - loss_fn=None, + def __init__(self, model, model_desc, optim_config, + loss_fn=None, options=None): assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" @@ -532,13 +532,12 @@ class ORTTrainer(object): sample_inputs_copy = copy.deepcopy(sample_inputs) # Handle contrib OPs support - from onnxruntime.training import register_custom_ops_pytorch_exporter + from onnxruntime.tools import pytorch_export_contrib_ops if self.options._internal_use.enable_onnx_contrib_ops: - # Enable contrib ops export from PyTorch - register_custom_ops_pytorch_exporter.register_custom_op() + pytorch_export_contrib_ops.register() else: - # Unregister contrib ops, if they were registered in previous calls - register_custom_ops_pytorch_exporter.unregister_custom_op() + # Unregister in case they were registered in previous calls. + pytorch_export_contrib_ops.unregister() # Export torch.nn.Module to ONNX torch.onnx._export(model, tuple(sample_inputs_copy), f, @@ -566,9 +565,9 @@ class ORTTrainer(object): return onnx_model - def _create_ort_training_session(self, - optimizer_state_dict={}, - session_options=None, + def _create_ort_training_session(self, + optimizer_state_dict={}, + session_options=None, provider_options=None): # Validating frozen_weights names unused_frozen_weights = [n for n in self.options.utils.frozen_weights\ @@ -622,7 +621,7 @@ class ORTTrainer(object): self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) self.options.distributed.data_parallel_size = self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - + # TrainingParameters ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = loss_name @@ -753,7 +752,7 @@ class ORTTrainer(object): # Create training session used by train_step # pass all optimizer states to the backend self._create_ort_training_session(optimizer_state_dict, - session_options=session_options, + session_options=session_options, provider_options=provider_options) # Update model description to update dtype when mixed precision is enabled @@ -880,8 +879,8 @@ class ORTTrainer(object): # This prevents CPU -> GPU -> CPU copies between frontend and backend target_device = 'cpu' # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. + # in that case, we temporary prefer to leave the input/output on CPU and let ORT session + # to move the data between device and host. # so output will be on the same device as input. try: test_pt_device = torch.device(target_device) @@ -889,7 +888,7 @@ class ORTTrainer(object): #in this case, input/output must on CPU assert(input.device.type == 'cpu') target_device = 'cpu' - + torch_tensor = torch.zeros(output_desc.shape, device=target_device, dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype) iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(target_device), @@ -1282,7 +1281,7 @@ class ORTTrainer(object): if self._training_session: current_state_dict = self.state_dict() if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow + # for Zero enabled, the current trainer might not have the complete state, and we must allow # extra keys to be present in the state dict allow_unexpected = True if self.options.distributed.deepspeed_zero_optimization.stage > 0 else False _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) @@ -1360,7 +1359,7 @@ class ORTTrainer(object): def _aggregation_required(self, loaded_trainer_options): """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - # To load states in the backend, aggregation is required for every ZeRO + # To load states in the backend, aggregation is required for every ZeRO # or Megatron checkpoint return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \ loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9f6902222c..d96418cc7e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1436,6 +1436,9 @@ def run_training_python_frontend_tests(cwd): run_subprocess([ sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd) + # Not technically training related, but it needs torch to be installed. + run_subprocess([ + sys.executable, '-m', 'pytest', '-sv', 'test_pytorch_export_contrib_ops.py'], cwd=cwd) def run_training_python_frontend_e2e_tests(cwd):