[ONNX] New registration API (#135403)

This PR adds the custom op registration API for the ONNX exporter: a `custom_translation_table` argument to `torch.onnx.export` that lets users supply their own translation functions.

## Design

1. Add a `custom_translation_table: dict[Callable, Callable | Sequence[Callable]]` parameter to `torch.onnx.export` for specifying extra translation functions.
2. Use a callable as the key so that every possible `call_function` target in the FX graph can be supported.
3. Allow a single callable or a Sequence of callables as the value.
    - A single callable is the translation function for the op.
    - For a Sequence of callables, the exporter's dispatcher tries them in order and dispatches based on the input dtypes (see the sketch after the example below).
    - A translation function can be a plain Python function that calls onnxscript ops (traced), or an onnxscript function.
    - Complex input support: we create special type annotations for the real representations of complex inputs, which are needed to express complex computation in the ONNX graph because ONNX has no ops that handle complex inputs. The dispatcher knows about these annotations and dispatches accordingly; the complex functions live in the same overload pool as the real-valued ones.

```py
torch.onnx.export(
    model,
    args,
    custom_translation_table={
        torch.ops.aten.add.Tensor: [overload1, overload2],
        torch.sym_not: sym_not_onnx,
    },
    dynamo=True,
)
```
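To make the overload dispatch in point 3 concrete, here is a minimal sketch under the API described above. The `Model`, `add_float`, and `add_int` names are illustrative only and not part of this PR; the registration pattern mirrors the new tests added below.

```py
import torch
from onnxscript import FLOAT, INT64, opset18 as op


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


# A plain Python function that calls onnxscript ops (traced); the onnxscript
# type annotations tell the dispatcher which input dtypes this overload handles.
def add_float(self: FLOAT, other: FLOAT) -> FLOAT:
    return op.Add(self, other)


# Overload selected when the inputs are int64 tensors.
def add_int(self: INT64, other: INT64) -> INT64:
    return op.Add(self, other)


onnx_program = torch.onnx.export(
    Model(),
    (torch.randn(2, 2), torch.randn(2, 2)),  # float inputs dispatch to add_float
    custom_translation_table={
        # The dispatcher tries these in order, matching on the input dtypes.
        torch.ops.aten.add.Tensor: [add_float, add_int],
    },
    dynamo=True,
)
```

Because `add_float` appears first in the sequence, it is tried first; the dispatcher falls through to `add_int` when the inputs are int64 tensors.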
Support for functions that handle complex inputs will land in separate PRs.
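Separately, because the key is the callable target of an FX `call_function` node, operators defined via `torch.library` are handled the same way; this is exercised by `test_custom_translation_table_supports_custom_op_as_target` in the diff below. A minimal sketch, where the `mylib::scaled_add` op and its translation are hypothetical:

```py
import torch
from onnxscript import FLOAT, opset18 as op


# Hypothetical custom op registered through torch.library.
@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + 2 * b


@scaled_add.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(a)


# ONNX translation for the custom op, built from plain ONNX ops.
def scaled_add_onnx(a: FLOAT, b: FLOAT) -> FLOAT:
    return op.Add(a, op.Mul(b, op.Constant(value_float=2.0)))


class Model(torch.nn.Module):
    def forward(self, x, y):
        return scaled_add(x, y)


onnx_program = torch.onnx.export(
    Model(),
    (torch.randn(2, 2), torch.randn(2, 2)),
    custom_translation_table={torch.ops.mylib.scaled_add.default: scaled_add_onnx},
    dynamo=True,
)
```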

fixes https://github.com/pytorch/pytorch/issues/138391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135403
Approved by: https://github.com/titaiwangms
Justin Chu 2024-11-01 20:58:51 +00:00 committed by PyTorch MergeBot
parent f5b9e725d1
commit 5d67efb809
4 changed files with 139 additions and 18 deletions


@@ -5,6 +5,8 @@ from __future__ import annotations
import os
from onnxscript import BOOL, FLOAT, opset18 as op
import torch
import torch.onnx._flags
from torch.onnx._internal.exporter import _testing as onnx_testing
@@ -206,5 +208,91 @@ class TestExportAPIDynamo(common_utils.TestCase):
        self.assert_export(Model(), (input))


class TestCustomTranslationTable(common_utils.TestCase):
    def test_custom_translation_table_overrides_ops(self):
        from onnxscript import opset18 as op

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def custom_add(self, other):
            # Replace add with sub
            return op.Sub(self, other)

        custom_translation_table = {torch.ops.aten.add.Tensor: custom_add}
        onnx_program = torch.onnx.export(
            Model(),
            (torch.randn(2, 2), torch.randn(2, 2)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)

    def test_custom_translation_table_supports_overloading_ops(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.logical_and.default(x, y)

        def custom_add_bool(self: BOOL, other: BOOL) -> BOOL:
            # Replace logical_and with Sub for boolean inputs
            return op.Sub(self, other)

        def custom_add(self: FLOAT, other: FLOAT) -> FLOAT:
            # Replace logical_and with Mul for float inputs
            return op.Mul(self, other)

        custom_translation_table = {
            torch.ops.aten.logical_and.default: [custom_add, custom_add_bool],
        }
        onnx_program = torch.onnx.export(
            Model(),
            (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        # The dispatcher should pick the correct overload based on the input types
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)
        self.assertNotIn("Mul", all_nodes)

    def test_custom_translation_table_supports_custom_op_as_target(self):
        # Define the custom op and use it in the model
        @torch.library.custom_op("custom::add", mutates_args=())
        def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a + b

        @custom_add.register_fake
        def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return torch.empty_like(a) + torch.empty_like(b)

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return custom_add(x, y)

        def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
            # Replace add with Sub
            return op.Sub(self, other)

        custom_translation_table = {
            torch.ops.custom.add.default: onnx_add,
        }
        onnx_program = torch.onnx.export(
            Model(),
            (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
            custom_translation_table=custom_translation_table,
            dynamo=True,
        )
        all_nodes = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sub", all_nodes)
        self.assertNotIn("Add", all_nodes)


if __name__ == "__main__":
    common_utils.run_tests()


@@ -148,6 +148,8 @@ def export(
    # Dynamo only options
    external_data: bool = True,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
    | None = None,
    report: bool = False,
    optimize: bool = False,
    verify: bool = False,
@@ -280,15 +282,24 @@ def export(
            :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
            Only one parameter `dynamic_axes` or `dynamic_shapes` should be set
            at the same time.
        report: Whether to generate a markdown report for the export process.
        optimize: Whether to optimize the exported model.
        verify: Whether to verify the exported model using ONNX Runtime.
        profile: Whether to profile the export process.
        custom_translation_table: A dictionary of custom decompositions for operators in the model.
            The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``),
            and the value should be a function that builds that graph using ONNX Script. This option
            is only valid when dynamo is True.
        report: Whether to generate a markdown report for the export process. This option
            is only valid when dynamo is True.
        optimize: Whether to optimize the exported model. This option
            is only valid when dynamo is True.
        verify: Whether to verify the exported model using ONNX Runtime. This option
            is only valid when dynamo is True.
        profile: Whether to profile the export process. This option
            is only valid when dynamo is True.
        dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
            This is useful for debugging the exporter.
            This is useful for debugging the exporter. This option is only valid when dynamo is True.
        artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
            exported program.
            exported program. This option is only valid when dynamo is True.
        fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.
            This option is only valid when dynamo is True.
        training: Deprecated option. Instead, set the training mode of the model before exporting.
        operator_export_type: Deprecated option. Only ONNX is supported.
@@ -346,6 +357,7 @@ def export(
            input_names=input_names,
            output_names=output_names,
            opset_version=opset_version,
            custom_translation_table=custom_translation_table,
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            external_data=external_data,


@@ -6,11 +6,11 @@ from __future__ import annotations
import inspect
import logging
from typing import Any, Mapping, Sequence, TYPE_CHECKING
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING

import torch
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal.exporter import _core, _onnx_program
from torch.onnx._internal.exporter import _core, _onnx_program, _registration

if TYPE_CHECKING:
@@ -125,6 +125,8 @@ def export_compat(
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    opset_version: int | None = None,
    custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
    | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
@@ -158,12 +160,22 @@ def export_compat(
            output_names=set(output_names or ()),
        )

    registry = _registration.ONNXRegistry.from_torchlib()
    if custom_translation_table is not None:
        for torch_op, onnx_ops in custom_translation_table.items():
            # TODO(justinchuby): Support complex inputs with annotations
            if not isinstance(onnx_ops, Sequence):
                onnx_ops = (onnx_ops,)
            for op in reversed(onnx_ops):
                # register_op places the op in the front of all onnx variants,
                # so we reverse the list to maintain the order of the custom ops provided
                registry.register_op(torch_op, op, is_complex=False)
    try:
        onnx_program = _core.export(
            model,
            args,
            kwargs,
            registry=None,
            registry=registry,
            dynamic_shapes=dynamic_shapes,
            input_names=input_names,
            output_names=output_names,


@@ -17,19 +17,15 @@ import logging
import math
import operator
import types
import typing
from typing import Callable, Literal, Union

from typing_extensions import TypeAlias

import torch
import torch._ops
from torch.onnx._internal._lazy_import import onnxscript_apis
from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis
from torch.onnx._internal.exporter import _schemas

if typing.TYPE_CHECKING:
    import onnxscript

_DEFAULT_OPSET_VERSION = 18
@@ -153,9 +149,6 @@ class ONNXRegistry:
            try:
                # NOTE: This is heavily guarded with try-except because we don't want
                # to fail the entire registry population if one function fails.
                if qualified_name.startswith("internal::"):
                    # Skip the custom defined internal functions
                    continue
                target = _get_overload(qualified_name)
                if target is None:
                    continue
@@ -203,7 +196,7 @@ class ONNXRegistry:
    def register_op(
        self,
        target: TorchOp,
        function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
        function: Callable,
        is_complex: bool = False,
    ) -> None:
        """Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
@@ -213,6 +206,22 @@ class ONNXRegistry:
            function: The onnx-script function to register.
            is_complex: Whether the function is a function that handles complex valued inputs.
        """
        if not hasattr(function, "signature"):
            try:
                # TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
                if isinstance(function, onnxscript.OnnxFunction):
                    function.signature = _schemas.OpSignature.from_function(  # type: ignore[attr-defined]
                        function, function.function_ir.domain, function.name
                    )
                else:
                    function.signature = _schemas.OpSignature.from_function(  # type: ignore[attr-defined]
                        function, "__custom", function.__name__
                    )
            except Exception:
                logger.exception(
                    "Failed to infer the signature for function '%s'", function
                )

        onnx_decomposition = OnnxDecompMeta(
            onnx_function=function,
            fx_target=target,