diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 1b31cea1226..ce04b7e6a9a 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -5,6 +5,8 @@ from __future__ import annotations import os +from onnxscript import BOOL, FLOAT, opset18 as op + import torch import torch.onnx._flags from torch.onnx._internal.exporter import _testing as onnx_testing @@ -206,5 +208,91 @@ class TestExportAPIDynamo(common_utils.TestCase): self.assert_export(Model(), (input)) +class TestCustomTranslationTable(common_utils.TestCase): + def test_custom_translation_table_overrides_ops(self): + from onnxscript import opset18 as op + + class Model(torch.nn.Module): + def forward(self, x, y): + return x + y + + def custom_add(self, other): + # Replace add with sub + return op.Sub(self, other) + + custom_translation_table = {torch.ops.aten.add.Tensor: custom_add} + + onnx_program = torch.onnx.export( + Model(), + (torch.randn(2, 2), torch.randn(2, 2)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + + def test_custom_translation_table_supports_overloading_ops(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.logical_and.default(x, y) + + def custom_add_bool(self: BOOL, other: BOOL) -> BOOL: + # Replace add with sub + return op.Sub(self, other) + + def custom_add(self: FLOAT, other: FLOAT) -> FLOAT: + # Replace add with mul + return op.Mul(self, other) + + custom_translation_table = { + torch.ops.aten.logical_and.default: [custom_add, custom_add_bool], + } + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + # The dispatcher should pick the correct overload based on the input types + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + self.assertNotIn("Mul", all_nodes) + + def test_custom_translation_table_supports_custom_op_as_target(self): + # Define the custom op and use it in the model + @torch.library.custom_op("custom::add", mutates_args=()) + def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + @custom_add.register_fake + def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + torch.empty_like(b) + + class Model(torch.nn.Module): + def forward(self, x, y): + return custom_add(x, y) + + def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT: + # Replace add with Sub + return op.Sub(self, other) + + custom_translation_table = { + torch.ops.custom.add.default: onnx_add, + } + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), + custom_translation_table=custom_translation_table, + dynamo=True, + ) + all_nodes = [n.op_type for n in onnx_program.model.graph] + self.assertIn("Sub", all_nodes) + self.assertNotIn("Add", all_nodes) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index b15a45a4d17..02839e1ef4f 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -148,6 +148,8 @@ def export( # Dynamo only options external_data: bool = True, dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, report: bool = False, optimize: bool = False, verify: bool = False, @@ -280,15 +282,24 @@ def export( :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. Only one parameter `dynamic_axes` or `dynamic_shapes` should be set at the same time. - report: Whether to generate a markdown report for the export process. - optimize: Whether to optimize the exported model. - verify: Whether to verify the exported model using ONNX Runtime. - profile: Whether to profile the export process. + custom_translation_table: A dictionary of custom decompositions for operators in the model. + The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``), + and the value should be a function that builds that graph using ONNX Script. This option + is only valid when dynamo is True. + report: Whether to generate a markdown report for the export process. This option + is only valid when dynamo is True. + optimize: Whether to optimize the exported model. This option + is only valid when dynamo is True. + verify: Whether to verify the exported model using ONNX Runtime. This option + is only valid when dynamo is True. + profile: Whether to profile the export process. This option + is only valid when dynamo is True. dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. - This is useful for debugging the exporter. + This is useful for debugging the exporter. This option is only valid when dynamo is True. artifacts_dir: The directory to save the debugging artifacts like the report and the serialized - exported program. + exported program. This option is only valid when dynamo is True. fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + This option is only valid when dynamo is True. training: Deprecated option. Instead, set the training mode of the model before exporting. operator_export_type: Deprecated option. Only ONNX is supported. @@ -346,6 +357,7 @@ def export( input_names=input_names, output_names=output_names, opset_version=opset_version, + custom_translation_table=custom_translation_table, dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, external_data=external_data, diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 411384f1d36..f54a7b39b83 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -6,11 +6,11 @@ from __future__ import annotations import inspect import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING import torch from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir -from torch.onnx._internal.exporter import _core, _onnx_program +from torch.onnx._internal.exporter import _core, _onnx_program, _registration if TYPE_CHECKING: @@ -125,6 +125,8 @@ def export_compat( input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, opset_version: int | None = None, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None, @@ -158,12 +160,22 @@ def export_compat( output_names=set(output_names or ()), ) + registry = _registration.ONNXRegistry.from_torchlib() + if custom_translation_table is not None: + for torch_op, onnx_ops in custom_translation_table.items(): + # TODO(justinchuby): Support complex inputs with annotations + if not isinstance(onnx_ops, Sequence): + onnx_ops = (onnx_ops,) + for op in reversed(onnx_ops): + # register_op places the op in the front of all onnx variants, + # so we reverse the list to maintain the order of the custom ops provided + registry.register_op(torch_op, op, is_complex=False) try: onnx_program = _core.export( model, args, kwargs, - registry=None, + registry=registry, dynamic_shapes=dynamic_shapes, input_names=input_names, output_names=output_names, diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index 0afa084b06c..a03052ef6f0 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -17,19 +17,15 @@ import logging import math import operator import types -import typing from typing import Callable, Literal, Union from typing_extensions import TypeAlias import torch import torch._ops -from torch.onnx._internal._lazy_import import onnxscript_apis +from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis from torch.onnx._internal.exporter import _schemas -if typing.TYPE_CHECKING: - import onnxscript - _DEFAULT_OPSET_VERSION = 18 @@ -153,9 +149,6 @@ class ONNXRegistry: try: # NOTE: This is heavily guarded with try-except because we don't want # to fail the entire registry population if one function fails. - if qualified_name.startswith("internal::"): - # Skip the custom defined internal functions - continue target = _get_overload(qualified_name) if target is None: continue @@ -203,7 +196,7 @@ class ONNXRegistry: def register_op( self, target: TorchOp, - function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + function: Callable, is_complex: bool = False, ) -> None: """Registers a custom operator: torch.ops.... @@ -213,6 +206,22 @@ class ONNXRegistry: function: The onnx-script function to register. is_complex: Whether the function is a function that handles complex valued inputs. """ + if not hasattr(function, "signature"): + try: + # TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI + if isinstance(function, onnxscript.OnnxFunction): + function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + function, function.function_ir.domain, function.name + ) + else: + function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + function, "__custom", function.__name__ + ) + except Exception: + logger.exception( + "Failed to infer the signature for function '%s'", function + ) + onnx_decomposition = OnnxDecompMeta( onnx_function=function, fx_target=target,