Include pytorch_export_contrib_ops in inference builds (#8878)

* Include pytorch_export_contrib_ops in inference builds

Rename / move it from tools/python/register_custom_ops_pytorch_exporter
to onnxruntime/python/tools/pytorch_export_contrib_ops.

Rationale for inclusion in inference builds:
This code is potentially useful for anyone using ORT, not just training.

Rationale for new name:
"Contrib op" is the nomenclature used within ORT to refer to the set of
ops that are not in the standard op set but are included by default with
ORT. This is more specific than "custom op", which is what the PyTorch
exporter uses to refer to any non-standard op.

Step 1 of addressing #8818. After this is merged I will update the docs.

* Enable test_pytorch_export_contrib_ops.py in CI

Fixes AB#1342330
This commit is contained in:
Gary Miguel 2021-09-02 14:26:58 -07:00 committed by GitHub
parent 06bb2ec561
commit 47435311f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 86 additions and 87 deletions

View file

@ -19,11 +19,11 @@ endif()
file(GLOB onnxruntime_pybind_srcs CONFIGURE_DEPENDS
${onnxruntime_pybind_srcs_pattern}
)
if(NOT onnxruntime_PYBIND_EXPORT_OPSCHEMA)
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_schema.cc)
endif()
if(onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc)
endif()
@ -38,11 +38,11 @@ if (onnxruntime_ENABLE_EAGER_MODE)
)
if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_eager_extension_srcs
list(APPEND onnxruntime_eager_extension_srcs
"${ORTTRAINING_ROOT}/orttraining/core/framework/torch/dlpack_python.cc")
endif()
list(APPEND onnxruntime_pybind_srcs
list(APPEND onnxruntime_pybind_srcs
${onnxruntime_eager_extension_srcs})
endif()
@ -286,9 +286,6 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/*"
)
file(GLOB onnxruntime_python_train_tools_srcs CONFIGURE_DEPENDS
"${REPO_ROOT}/tools/python/register_custom_ops_pytorch_exporter.py"
)
else()
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/training/*.py"
@ -557,9 +554,6 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_torch_gpu_allocator_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_train_tools_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/
)
endif()

View file

@ -1,21 +1,38 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Register pytorch symbolic for export using ONNX Runtime contrib ops
from torch.onnx import register_custom_op_symbolic
"""
Support for registering ONNX Runtime's built-in contrib ops with
PyTorch-ONNX exporter (torch.onnx.export).
"""
import typing
try:
from torch.onnx import register_custom_op_symbolic
except ModuleNotFoundError:
raise ModuleNotFoundError(
"This module is only useful in combination with PyTorch. "
"To install PyTorch see https://pytorch.org/.")
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_registry as sym_registry
_onnx_opset_version = 1
_OPSET_VERSION = 1
_registered_ops: typing.AbstractSet[str] = set()
def register_custom_op():
"""
This function registers symbolic functions for
custom ops that are implemented as part of ONNX Runtime
def _reg(symbolic_fn: typing.Callable):
name = "::%s" % symbolic_fn.__name__
register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION)
_registered_ops.add(name)
def register():
"""Register ONNX Runtime's built-in contrib ops.
Should be run before torch.onnx.export().
"""
# Symbolic definition
def grid_sample(g, input, grid, mode, padding_mode, align_corners):
# mode
# 'bilinear' : onnx::Constant[value={0}]
@ -42,46 +59,33 @@ def register_custom_op():
mode_s=mode_str,
padding_mode_s=padding_mode_str,
align_corners_i=align_corners)
_reg(grid_sample)
def inverse(g, self):
return g.op("com.microsoft::Inverse", self).setType(self.type())
_reg(inverse)
def gelu(g, self):
return g.op("com.microsoft::Gelu", self).setType(self.type())
_reg(gelu)
def triu(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())
_reg(triu)
def tril(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())
# Op Registration
register_custom_op_symbolic('::grid_sampler', grid_sample, _onnx_opset_version)
register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
register_custom_op_symbolic('::tril', tril, _onnx_opset_version)
_reg(tril)
def unregister_custom_op():
"""
This function unregisters symbolic functions for
custom ops that are implemented as part of ONNX Runtime
"""
import torch.onnx.symbolic_registry as sym_registry
def unregister():
"""Unregister ONNX Runtime's built-in contrib ops."""
# TODO: replace this once PyTorch supports unregister natively.
def unregister(name, opset_version):
# https://msdata.visualstudio.com/Vienna/_workitems/edit/1342343
for name in _registered_ops:
ns, kind = name.split("::")
from torch.onnx.symbolic_helper import _onnx_stable_opsets
for version in _onnx_stable_opsets:
if version >= opset_version and sym_registry.is_registered_op(kind, ns, version):
for version in sym_help._onnx_stable_opsets:
if (version >= _OPSET_VERSION and
sym_registry.is_registered_op(kind, ns, version)):
del sym_registry._registry[(ns, version)][kind]
unregister('::grid_sampler', _onnx_opset_version)
unregister('::inverse', _onnx_opset_version)
unregister('::gelu', _onnx_opset_version)
unregister('::triu', _onnx_opset_version)
unregister('::tril', _onnx_opset_version)

View file

@ -1,15 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Test export of pytorch operators using ONNX Runtime contrib ops
"""Test export of PyTorch operators using ONNX Runtime contrib ops."""
import torch
import onnxruntime
from onnxruntime.tools import pytorch_export_contrib_ops
import numpy as np
import unittest
import io
import copy
from python.register_custom_ops_pytorch_exporter import register_custom_op
def ort_test_with_input(ort_sess, input, output, rtol, atol):
@ -35,9 +35,8 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol):
[np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch.
# To register custom ops and run the tests, you should set PYTHONPATH as:
# PYTHONPATH=<path_to_onnxruntime/tools> python -m pytest -v test_custom_ops_pytorch_exporter.py
# These set of tests verify ONNX model export and compares outputs between
# PyTorch and ORT.
class ONNXExporterTest(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
@ -45,7 +44,7 @@ class ONNXExporterTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(0)
register_custom_op()
pytorch_export_contrib_ops.register()
def run_test(self, model, input=None,
custom_opsets=None,
@ -103,12 +102,12 @@ class ONNXExporterTest(unittest.TestCase):
return torch.inverse(x) + x
x = torch.randn(2, 3, 3)
self.run_test(CustomInverse(), x, custom_opsets={'com.microsoft': 1})
self.run_test(CustomInverse(), x, custom_opsets={"com.microsoft": 1})
def test_gelu(self):
model = torch.nn.GELU()
x = torch.randn(3, 3)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
def test_triu(self):
for i in range(-5, 5):
@ -118,13 +117,13 @@ class ONNXExporterTest(unittest.TestCase):
model = Module()
x = torch.randn(5, 4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(5, 4, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(5, 0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
for i in range(-5, 5):
class Module2D(torch.nn.Module):
@ -133,13 +132,13 @@ class ONNXExporterTest(unittest.TestCase):
model = Module2D()
x = torch.randn(4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(0, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
def test_tril(self):
for i in range(-5, 5):
@ -149,13 +148,13 @@ class ONNXExporterTest(unittest.TestCase):
model = Module()
x = torch.randn(5, 4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(5, 4, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(5, 0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
for i in range(-5, 5):
class Module2D(torch.nn.Module):
@ -164,13 +163,13 @@ class ONNXExporterTest(unittest.TestCase):
model = Module2D()
x = torch.randn(4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(0, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
x = torch.randn(0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
self.run_test(model, x, custom_opsets={"com.microsoft": 1})
# opset 9 tests, with keep_initializers_as_inputs=False for
@ -181,5 +180,5 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
keep_initializers_as_inputs=False))
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View file

@ -355,8 +355,8 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
sample_inputs_copy = copy.deepcopy(sample_inputs)
# Enable contrib ops export from PyTorch
from onnxruntime.training import register_custom_ops_pytorch_exporter
register_custom_ops_pytorch_exporter.register_custom_op()
from onnxruntime.tools import pytorch_export_contrib_ops
pytorch_export_contrib_ops.register()
torch.onnx._export(model, tuple(sample_inputs_copy), f,
input_names=input_names,

View file

@ -10,7 +10,7 @@ from ._custom_gradient_registry import CustomGradientRegistry
from .debug_options import DebugOptions
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception
from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY
from onnxruntime.training import register_custom_ops_pytorch_exporter
from onnxruntime.tools import pytorch_export_contrib_ops
import functools
import torch
@ -71,7 +71,7 @@ class ORTModule(torch.nn.Module):
super(ORTModule, self).__init__()
# Support contrib OPs
register_custom_ops_pytorch_exporter.register_custom_op()
pytorch_export_contrib_ops.register()
CustomOpSymbolicRegistry.register_all()
CustomGradientRegistry.register_all()

View file

@ -120,8 +120,8 @@ class ORTTrainer(object):
ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn)
"""
def __init__(self, model, model_desc, optim_config,
loss_fn=None,
def __init__(self, model, model_desc, optim_config,
loss_fn=None,
options=None):
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
@ -532,13 +532,12 @@ class ORTTrainer(object):
sample_inputs_copy = copy.deepcopy(sample_inputs)
# Handle contrib OPs support
from onnxruntime.training import register_custom_ops_pytorch_exporter
from onnxruntime.tools import pytorch_export_contrib_ops
if self.options._internal_use.enable_onnx_contrib_ops:
# Enable contrib ops export from PyTorch
register_custom_ops_pytorch_exporter.register_custom_op()
pytorch_export_contrib_ops.register()
else:
# Unregister contrib ops, if they were registered in previous calls
register_custom_ops_pytorch_exporter.unregister_custom_op()
# Unregister in case they were registered in previous calls.
pytorch_export_contrib_ops.unregister()
# Export torch.nn.Module to ONNX
torch.onnx._export(model, tuple(sample_inputs_copy), f,
@ -566,9 +565,9 @@ class ORTTrainer(object):
return onnx_model
def _create_ort_training_session(self,
optimizer_state_dict={},
session_options=None,
def _create_ort_training_session(self,
optimizer_state_dict={},
session_options=None,
provider_options=None):
# Validating frozen_weights names
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
@ -622,7 +621,7 @@ class ORTTrainer(object):
self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1)
self.options.distributed.data_parallel_size = self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size
# TrainingParameters
ort_parameters = ort.TrainingParameters()
ort_parameters.loss_output_name = loss_name
@ -753,7 +752,7 @@ class ORTTrainer(object):
# Create training session used by train_step
# pass all optimizer states to the backend
self._create_ort_training_session(optimizer_state_dict,
session_options=session_options,
session_options=session_options,
provider_options=provider_options)
# Update model description to update dtype when mixed precision is enabled
@ -880,8 +879,8 @@ class ORTTrainer(object):
# This prevents CPU -> GPU -> CPU copies between frontend and backend
target_device = 'cpu'
# the self.options.device may be a device that pytorch does not recognize.
# in that case, we temporary prefer to leave the input/output on CPU and let ORT session
# to move the data between device and host.
# in that case, we temporary prefer to leave the input/output on CPU and let ORT session
# to move the data between device and host.
# so output will be on the same device as input.
try:
test_pt_device = torch.device(target_device)
@ -889,7 +888,7 @@ class ORTTrainer(object):
#in this case, input/output must on CPU
assert(input.device.type == 'cpu')
target_device = 'cpu'
torch_tensor = torch.zeros(output_desc.shape, device=target_device,
dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(target_device),
@ -1282,7 +1281,7 @@ class ORTTrainer(object):
if self._training_session:
current_state_dict = self.state_dict()
if strict:
# for Zero enabled, the current trainer might not have the complete state, and we must allow
# for Zero enabled, the current trainer might not have the complete state, and we must allow
# extra keys to be present in the state dict
allow_unexpected = True if self.options.distributed.deepspeed_zero_optimization.stage > 0 else False
_check_key_mismatch(current_state_dict, state_dict, allow_unexpected)
@ -1360,7 +1359,7 @@ class ORTTrainer(object):
def _aggregation_required(self, loaded_trainer_options):
"""Checks if aggregation is required for the loading the state_dict into the ORTTrainer"""
# To load states in the backend, aggregation is required for every ZeRO
# To load states in the backend, aggregation is required for every ZeRO
# or Megatron checkpoint
return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \
loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1

View file

@ -1436,6 +1436,9 @@ def run_training_python_frontend_tests(cwd):
run_subprocess([
sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd)
# Not technically training related, but it needs torch to be installed.
run_subprocess([
sys.executable, '-m', 'pytest', '-sv', 'test_pytorch_export_contrib_ops.py'], cwd=cwd)
def run_training_python_frontend_e2e_tests(cwd):