From 5d67efb809fd207e7e344b34c68a3a4b0a68b56f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 1 Nov 2024 20:58:51 +0000 Subject: [PATCH] [ONNX] New registration API (#135403) The ONNX custom ops registration API. ## Design 1. Create a "custom_translation_table: dict[Callable, Sequence[Callable] | Callable" parameter for specifying extra functions 2. Use a callable as the key to support all possible call_function targets in the fx graph 3. Allow a callable or a Sequence of callables as values. - When there is a single callable, it is the translation function for the op - When there is a Sequence of callable, the exporter's dispatcher will dispatch to these callables in order based on input dtypes. - The translation functions can be a plain python function that calls onnxscript ops (traced), or an onnxscript function. - Complex input support: We create special type annotations for annotating real representations of complex inputs, which are needed to handle complex computation in the ONNX graph, as we don't have any ops in ONNX that handle complex inputs. The dispatcher will have knowledge of these newly created type annotations and dispatch correctly. The complex functions will be in the same overload pool as the real functions. ```py torch.onnx.export(dynamo=True, custom_translation_table = { torch.ops.aten.add: [overload1, overload2], torch.sym_not: sym_not_onnx, }) ``` Support for functions that handles complex inputs will be in separate PRs. fixes https://github.com/pytorch/pytorch/issues/138391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135403 Approved by: https://github.com/titaiwangms --- test/onnx/exporter/test_api.py | 88 +++++++++++++++++++ torch/onnx/__init__.py | 24 +++-- torch/onnx/_internal/exporter/_compat.py | 18 +++- .../onnx/_internal/exporter/_registration.py | 27 ++++-- 4 files changed, 139 insertions(+), 18 deletions(-) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 1b31cea1226..ce04b7e6a9a 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -5,6 +5,8 @@ from __future__ import annotations import os +from onnxscript import BOOL, FLOAT, opset18 as op + import torch import torch.onnx._flags from torch.onnx._internal.exporter import _testing as onnx_testing @@ -206,5 +208,91 @@ class TestExportAPIDynamo(common_utils.TestCase): self.assert_export(Model(), (input)) +class TestCustomTranslationTable(common_utils.TestCase): + def test_custom_translation_table_overrides_ops(self): + from onnxscript import opset18 as op + + class Model(torch.nn.Module): + def forward(self, x, y): + return x + y + + def custom_add(self, other): + # Replace add with sub + return op.Sub(self, other) + + custom_translation_table = {torch.ops.aten.add.Tensor: custom_add} + + onnx_program = torch.onnx.export( + Model(), + (torch.randn(2, 2), torch.randn(2, 2)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + + def test_custom_translation_table_supports_overloading_ops(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.logical_and.default(x, y) + + def custom_add_bool(self: BOOL, other: BOOL) -> BOOL: + # Replace add with sub + return op.Sub(self, other) + + def custom_add(self: FLOAT, other: FLOAT) -> FLOAT: + # Replace add with mul + return op.Mul(self, other) + + custom_translation_table = { + torch.ops.aten.logical_and.default: [custom_add, custom_add_bool], + } + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + # The dispatcher should pick the correct overload based on the input types + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + self.assertNotIn("Mul", all_nodes) + + def test_custom_translation_table_supports_custom_op_as_target(self): + # Define the custom op and use it in the model + @torch.library.custom_op("custom::add", mutates_args=()) + def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + @custom_add.register_fake + def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + torch.empty_like(b) + + class Model(torch.nn.Module): + def forward(self, x, y): + return custom_add(x, y) + + def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT: + # Replace add with Sub + return op.Sub(self, other) + + custom_translation_table = { + torch.ops.custom.add.default: onnx_add, + } + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index b15a45a4d17..02839e1ef4f 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -148,6 +148,8 @@ def export( # Dynamo only options external_data: bool = True, dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, report: bool = False, optimize: bool = False, verify: bool = False, @@ -280,15 +282,24 @@ def export( :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. Only one parameter `dynamic_axes` or `dynamic_shapes` should be set at the same time. - report: Whether to generate a markdown report for the export process. - optimize: Whether to optimize the exported model. - verify: Whether to verify the exported model using ONNX Runtime. - profile: Whether to profile the export process. + custom_translation_table: A dictionary of custom decompositions for operators in the model. + The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``), + and the value should be a function that builds that graph using ONNX Script. This option + is only valid when dynamo is True. + report: Whether to generate a markdown report for the export process. This option + is only valid when dynamo is True. + optimize: Whether to optimize the exported model. This option + is only valid when dynamo is True. + verify: Whether to verify the exported model using ONNX Runtime. This option + is only valid when dynamo is True. + profile: Whether to profile the export process. This option + is only valid when dynamo is True. dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. - This is useful for debugging the exporter. + This is useful for debugging the exporter. This option is only valid when dynamo is True. artifacts_dir: The directory to save the debugging artifacts like the report and the serialized - exported program. + exported program. This option is only valid when dynamo is True. fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + This option is only valid when dynamo is True. training: Deprecated option. Instead, set the training mode of the model before exporting. operator_export_type: Deprecated option. Only ONNX is supported. @@ -346,6 +357,7 @@ def export( input_names=input_names, output_names=output_names, opset_version=opset_version, + custom_translation_table=custom_translation_table, dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, external_data=external_data, diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 411384f1d36..f54a7b39b83 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -6,11 +6,11 @@ from __future__ import annotations import inspect import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING import torch from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir -from torch.onnx._internal.exporter import _core, _onnx_program +from torch.onnx._internal.exporter import _core, _onnx_program, _registration if TYPE_CHECKING: @@ -125,6 +125,8 @@ def export_compat( input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, opset_version: int | None = None, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None, @@ -158,12 +160,22 @@ def export_compat( output_names=set(output_names or ()), ) + registry = _registration.ONNXRegistry.from_torchlib() + if custom_translation_table is not None: + for torch_op, onnx_ops in custom_translation_table.items(): + # TODO(justinchuby): Support complex inputs with annotations + if not isinstance(onnx_ops, Sequence): + onnx_ops = (onnx_ops,) + for op in reversed(onnx_ops): + # register_op places the op in the front of all onnx variants, + # so we reverse the list to maintain the order of the custom ops provided + registry.register_op(torch_op, op, is_complex=False) try: onnx_program = _core.export( model, args, kwargs, - registry=None, + registry=registry, dynamic_shapes=dynamic_shapes, input_names=input_names, output_names=output_names, diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index 0afa084b06c..a03052ef6f0 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -17,19 +17,15 @@ import logging import math import operator import types -import typing from typing import Callable, Literal, Union from typing_extensions import TypeAlias import torch import torch._ops -from torch.onnx._internal._lazy_import import onnxscript_apis +from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis from torch.onnx._internal.exporter import _schemas -if typing.TYPE_CHECKING: - import onnxscript - _DEFAULT_OPSET_VERSION = 18 @@ -153,9 +149,6 @@ class ONNXRegistry: try: # NOTE: This is heavily guarded with try-except because we don't want # to fail the entire registry population if one function fails. - if qualified_name.startswith("internal::"): - # Skip the custom defined internal functions - continue target = _get_overload(qualified_name) if target is None: continue @@ -203,7 +196,7 @@ class ONNXRegistry: def register_op( self, target: TorchOp, - function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + function: Callable, is_complex: bool = False, ) -> None: """Registers a custom operator: torch.ops.... @@ -213,6 +206,22 @@ class ONNXRegistry: function: The onnx-script function to register. is_complex: Whether the function is a function that handles complex valued inputs. """ + if not hasattr(function, "signature"): + try: + # TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI + if isinstance(function, onnxscript.OnnxFunction): + function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + function, function.function_ir.domain, function.name + ) + else: + function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + function, "__custom", function.__name__ + ) + except Exception: + logger.exception( + "Failed to infer the signature for function '%s'", function + ) + onnx_decomposition = OnnxDecompMeta( onnx_function=function, fx_target=target,