mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Include pytorch_export_contrib_ops in inference builds (#8878)
* Include pytorch_export_contrib_ops in inference builds Rename / move it from tools/python/register_custom_ops_pytorch_exporter to onnxruntime/python/tools/pytorch_export_contrib_ops. Rationale for inclusion in inference builds: This code is potentially useful for anyone using ORT, not just training. Rationale for new name: "Contrib op" is the nomenclature used within ORT to refer to the set of ops that are not in the standard op set but are included by default with ORT. This is more specific than "custom op", which is what the PyTorch exporter uses to refer to any non-standard op. Step 1 of addressing #8818. After this is merged I will update the docs. * Enable test_pytorch_export_contrib_ops.py in CI Fixes AB#1342330
This commit is contained in:
parent
06bb2ec561
commit
47435311f4
7 changed files with 86 additions and 87 deletions
|
|
@ -19,11 +19,11 @@ endif()
|
|||
file(GLOB onnxruntime_pybind_srcs CONFIGURE_DEPENDS
|
||||
${onnxruntime_pybind_srcs_pattern}
|
||||
)
|
||||
|
||||
|
||||
if(NOT onnxruntime_PYBIND_EXPORT_OPSCHEMA)
|
||||
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_schema.cc)
|
||||
endif()
|
||||
|
||||
|
||||
if(onnxruntime_ENABLE_TRAINING)
|
||||
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc)
|
||||
endif()
|
||||
|
|
@ -38,11 +38,11 @@ if (onnxruntime_ENABLE_EAGER_MODE)
|
|||
)
|
||||
|
||||
if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
|
||||
list(APPEND onnxruntime_eager_extension_srcs
|
||||
list(APPEND onnxruntime_eager_extension_srcs
|
||||
"${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc")
|
||||
endif()
|
||||
|
||||
list(APPEND onnxruntime_pybind_srcs
|
||||
list(APPEND onnxruntime_pybind_srcs
|
||||
${onnxruntime_eager_extension_srcs})
|
||||
endif()
|
||||
|
||||
|
|
@ -286,9 +286,6 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*"
|
||||
)
|
||||
file(GLOB onnxruntime_python_train_tools_srcs CONFIGURE_DEPENDS
|
||||
"${REPO_ROOT}/tools/python/register_custom_ops_pytorch_exporter.py"
|
||||
)
|
||||
else()
|
||||
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/training/*.py"
|
||||
|
|
@ -557,9 +554,6 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_train_tools_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Register pytorch symbolic for export using ONNX Runtime contrib ops
|
||||
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
"""
|
||||
Support for registering ONNX Runtime's built-in contrib ops with
|
||||
PyTorch-ONNX exporter (torch.onnx.export).
|
||||
"""
|
||||
|
||||
import typing
|
||||
|
||||
try:
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"This module is only useful in combination with PyTorch. "
|
||||
"To install PyTorch see https://pytorch.org/.")
|
||||
import torch.onnx.symbolic_helper as sym_help
|
||||
import torch.onnx.symbolic_registry as sym_registry
|
||||
|
||||
_onnx_opset_version = 1
|
||||
_OPSET_VERSION = 1
|
||||
_registered_ops: typing.AbstractSet[str] = set()
|
||||
|
||||
|
||||
def register_custom_op():
|
||||
"""
|
||||
This function registers symbolic functions for
|
||||
custom ops that are implemented as part of ONNX Runtime
|
||||
def _reg(symbolic_fn: typing.Callable):
|
||||
name = "::%s" % symbolic_fn.__name__
|
||||
register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION)
|
||||
_registered_ops.add(name)
|
||||
|
||||
|
||||
def register():
|
||||
"""Register ONNX Runtime's built-in contrib ops.
|
||||
|
||||
Should be run before torch.onnx.export().
|
||||
"""
|
||||
|
||||
# Symbolic definition
|
||||
def grid_sample(g, input, grid, mode, padding_mode, align_corners):
|
||||
# mode
|
||||
# 'bilinear' : onnx::Constant[value={0}]
|
||||
|
|
@ -42,46 +59,33 @@ def register_custom_op():
|
|||
mode_s=mode_str,
|
||||
padding_mode_s=padding_mode_str,
|
||||
align_corners_i=align_corners)
|
||||
_reg(grid_sample)
|
||||
|
||||
def inverse(g, self):
|
||||
return g.op("com.microsoft::Inverse", self).setType(self.type())
|
||||
_reg(inverse)
|
||||
|
||||
def gelu(g, self):
|
||||
return g.op("com.microsoft::Gelu", self).setType(self.type())
|
||||
_reg(gelu)
|
||||
|
||||
def triu(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())
|
||||
_reg(triu)
|
||||
|
||||
def tril(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())
|
||||
|
||||
# Op Registration
|
||||
register_custom_op_symbolic('::grid_sampler', grid_sample, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::tril', tril, _onnx_opset_version)
|
||||
_reg(tril)
|
||||
|
||||
|
||||
def unregister_custom_op():
|
||||
"""
|
||||
This function unregisters symbolic functions for
|
||||
custom ops that are implemented as part of ONNX Runtime
|
||||
"""
|
||||
|
||||
import torch.onnx.symbolic_registry as sym_registry
|
||||
|
||||
def unregister():
|
||||
"""Unregister ONNX Runtime's built-in contrib ops."""
|
||||
# TODO: replace this once PyTorch supports unregister natively.
|
||||
def unregister(name, opset_version):
|
||||
# https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343
|
||||
for name in _registered_ops:
|
||||
ns, kind = name.split("::")
|
||||
from torch.onnx.symbolic_helper import _onnx_stable_opsets
|
||||
|
||||
for version in _onnx_stable_opsets:
|
||||
if version >= opset_version and sym_registry.is_registered_op(kind, ns, version):
|
||||
for version in sym_help._onnx_stable_opsets:
|
||||
if (version >= _OPSET_VERSION and
|
||||
sym_registry.is_registered_op(kind, ns, version)):
|
||||
del sym_registry._registry[(ns, version)][kind]
|
||||
|
||||
unregister('::grid_sampler', _onnx_opset_version)
|
||||
unregister('::inverse', _onnx_opset_version)
|
||||
unregister('::gelu', _onnx_opset_version)
|
||||
unregister('::triu', _onnx_opset_version)
|
||||
unregister('::tril', _onnx_opset_version)
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Test export of pytorch operators using ONNX Runtime contrib ops
|
||||
|
||||
"""Test export of PyTorch operators using ONNX Runtime contrib ops."""
|
||||
|
||||
import torch
|
||||
import onnxruntime
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
import numpy as np
|
||||
import unittest
|
||||
import io
|
||||
import copy
|
||||
from python.register_custom_ops_pytorch_exporter import register_custom_op
|
||||
|
||||
|
||||
def ort_test_with_input(ort_sess, input, output, rtol, atol):
|
||||
|
|
@ -35,9 +35,8 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol):
|
|||
[np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
|
||||
|
||||
|
||||
# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch.
|
||||
# To register custom ops and run the tests, you should set PYTHONPATH as:
|
||||
# PYTHONPATH=<path_to_onnxruntime/tools> python -m pytest -v test_custom_ops_pytorch_exporter.py
|
||||
# These set of tests verify ONNX model export and compares outputs between
|
||||
# PyTorch and ORT.
|
||||
class ONNXExporterTest(unittest.TestCase):
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
opset_version = _export_onnx_opset_version
|
||||
|
|
@ -45,7 +44,7 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
torch.manual_seed(0)
|
||||
register_custom_op()
|
||||
pytorch_export_contrib_ops.register()
|
||||
|
||||
def run_test(self, model, input=None,
|
||||
custom_opsets=None,
|
||||
|
|
@ -103,12 +102,12 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
return torch.inverse(x) + x
|
||||
|
||||
x = torch.randn(2, 3, 3)
|
||||
self.run_test(CustomInverse(), x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(CustomInverse(), x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
def test_gelu(self):
|
||||
model = torch.nn.GELU()
|
||||
x = torch.randn(3, 3)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
def test_triu(self):
|
||||
for i in range(-5, 5):
|
||||
|
|
@ -118,13 +117,13 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
|
||||
model = Module()
|
||||
x = torch.randn(5, 4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(5, 4, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(5, 0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
for i in range(-5, 5):
|
||||
class Module2D(torch.nn.Module):
|
||||
|
|
@ -133,13 +132,13 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
|
||||
model = Module2D()
|
||||
x = torch.randn(4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(0, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
def test_tril(self):
|
||||
for i in range(-5, 5):
|
||||
|
|
@ -149,13 +148,13 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
|
||||
model = Module()
|
||||
x = torch.randn(5, 4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(5, 4, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(5, 0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
for i in range(-5, 5):
|
||||
class Module2D(torch.nn.Module):
|
||||
|
|
@ -164,13 +163,13 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
|
||||
model = Module2D()
|
||||
x = torch.randn(4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(0, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
x = torch.randn(0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
|
||||
|
||||
|
||||
# opset 9 tests, with keep_initializers_as_inputs=False for
|
||||
|
|
@ -181,5 +180,5 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
|
|||
keep_initializers_as_inputs=False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -355,8 +355,8 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
|
|||
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
||||
|
||||
# Enable contrib ops export from PyTorch
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
pytorch_export_contrib_ops.register()
|
||||
|
||||
torch.onnx._export(model, tuple(sample_inputs_copy), f,
|
||||
input_names=input_names,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from ._custom_gradient_registry import CustomGradientRegistry
|
|||
from .debug_options import DebugOptions
|
||||
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception
|
||||
from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
|
||||
import functools
|
||||
import torch
|
||||
|
|
@ -71,7 +71,7 @@ class ORTModule(torch.nn.Module):
|
|||
super(ORTModule, self).__init__()
|
||||
|
||||
# Support contrib OPs
|
||||
register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
pytorch_export_contrib_ops.register()
|
||||
CustomOpSymbolicRegistry.register_all()
|
||||
CustomGradientRegistry.register_all()
|
||||
|
||||
|
|
|
|||
|
|
@ -120,8 +120,8 @@ class ORTTrainer(object):
|
|||
ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, model, model_desc, optim_config,
|
||||
loss_fn=None,
|
||||
def __init__(self, model, model_desc, optim_config,
|
||||
loss_fn=None,
|
||||
options=None):
|
||||
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
|
||||
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
|
||||
|
|
@ -532,13 +532,12 @@ class ORTTrainer(object):
|
|||
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
||||
|
||||
# Handle contrib OPs support
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
if self.options._internal_use.enable_onnx_contrib_ops:
|
||||
# Enable contrib ops export from PyTorch
|
||||
register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
pytorch_export_contrib_ops.register()
|
||||
else:
|
||||
# Unregister contrib ops, if they were registered in previous calls
|
||||
register_custom_ops_pytorch_exporter.unregister_custom_op()
|
||||
# Unregister in case they were registered in previous calls.
|
||||
pytorch_export_contrib_ops.unregister()
|
||||
|
||||
# Export torch.nn.Module to ONNX
|
||||
torch.onnx._export(model, tuple(sample_inputs_copy), f,
|
||||
|
|
@ -566,9 +565,9 @@ class ORTTrainer(object):
|
|||
|
||||
return onnx_model
|
||||
|
||||
def _create_ort_training_session(self,
|
||||
optimizer_state_dict={},
|
||||
session_options=None,
|
||||
def _create_ort_training_session(self,
|
||||
optimizer_state_dict={},
|
||||
session_options=None,
|
||||
provider_options=None):
|
||||
# Validating frozen_weights names
|
||||
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
|
||||
|
|
@ -622,7 +621,7 @@ class ORTTrainer(object):
|
|||
|
||||
self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1)
|
||||
self.options.distributed.data_parallel_size = self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size
|
||||
|
||||
|
||||
# TrainingParameters
|
||||
ort_parameters = ort.TrainingParameters()
|
||||
ort_parameters.loss_output_name = loss_name
|
||||
|
|
@ -753,7 +752,7 @@ class ORTTrainer(object):
|
|||
# Create training session used by train_step
|
||||
# pass all optimizer states to the backend
|
||||
self._create_ort_training_session(optimizer_state_dict,
|
||||
session_options=session_options,
|
||||
session_options=session_options,
|
||||
provider_options=provider_options)
|
||||
|
||||
# Update model description to update dtype when mixed precision is enabled
|
||||
|
|
@ -880,8 +879,8 @@ class ORTTrainer(object):
|
|||
# This prevents CPU -> GPU -> CPU copies between frontend and backend
|
||||
target_device = 'cpu'
|
||||
# the self.options.device may be a device that pytorch does not recognize.
|
||||
# in that case, we temporary prefer to leave the input/output on CPU and let ORT session
|
||||
# to move the data between device and host.
|
||||
# in that case, we temporary prefer to leave the input/output on CPU and let ORT session
|
||||
# to move the data between device and host.
|
||||
# so output will be on the same device as input.
|
||||
try:
|
||||
test_pt_device = torch.device(target_device)
|
||||
|
|
@ -889,7 +888,7 @@ class ORTTrainer(object):
|
|||
#in this case, input/output must on CPU
|
||||
assert(input.device.type == 'cpu')
|
||||
target_device = 'cpu'
|
||||
|
||||
|
||||
torch_tensor = torch.zeros(output_desc.shape, device=target_device,
|
||||
dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
|
||||
iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(target_device),
|
||||
|
|
@ -1282,7 +1281,7 @@ class ORTTrainer(object):
|
|||
if self._training_session:
|
||||
current_state_dict = self.state_dict()
|
||||
if strict:
|
||||
# for Zero enabled, the current trainer might not have the complete state, and we must allow
|
||||
# for Zero enabled, the current trainer might not have the complete state, and we must allow
|
||||
# extra keys to be present in the state dict
|
||||
allow_unexpected = True if self.options.distributed.deepspeed_zero_optimization.stage > 0 else False
|
||||
_check_key_mismatch(current_state_dict, state_dict, allow_unexpected)
|
||||
|
|
@ -1360,7 +1359,7 @@ class ORTTrainer(object):
|
|||
def _aggregation_required(self, loaded_trainer_options):
|
||||
"""Checks if aggregation is required for the loading the state_dict into the ORTTrainer"""
|
||||
|
||||
# To load states in the backend, aggregation is required for every ZeRO
|
||||
# To load states in the backend, aggregation is required for every ZeRO
|
||||
# or Megatron checkpoint
|
||||
return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \
|
||||
loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1
|
||||
|
|
|
|||
|
|
@ -1436,6 +1436,9 @@ def run_training_python_frontend_tests(cwd):
|
|||
|
||||
run_subprocess([
|
||||
sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd)
|
||||
# Not technically training related, but it needs torch to be installed.
|
||||
run_subprocess([
|
||||
sys.executable, '-m', 'pytest', '-sv', 'test_pytorch_export_contrib_ops.py'], cwd=cwd)
|
||||
|
||||
|
||||
def run_training_python_frontend_e2e_tests(cwd):
|
||||
|
|
|
|||
Loading…
Reference in a new issue