mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX] New registration API (#135403)
The ONNX custom ops registration API.
## Design
1. Create a "custom_translation_table: dict[Callable, Sequence[Callable] | Callable" parameter for specifying extra functions
2. Use a callable as the key to support all possible call_function targets in the fx graph
3. Allow a callable or a Sequence of callables as values.
- When there is a single callable, it is the translation function for the op
- When there is a Sequence of callable, the exporter's dispatcher will dispatch to these callables in order based on input dtypes.
- The translation functions can be a plain python function that calls onnxscript ops (traced), or an onnxscript function.
- Complex input support: We create special type annotations for annotating real representations of complex inputs, which are needed to handle complex computation in the ONNX graph, as we don't have any ops in ONNX that handle complex inputs. The dispatcher will have knowledge of these newly created type annotations and dispatch correctly. The complex functions will be in the same overload pool as the real functions.
```py
torch.onnx.export(dynamo=True,
custom_translation_table = {
torch.ops.aten.add: [overload1, overload2],
torch.sym_not: sym_not_onnx,
})
```
Support for functions that handles complex inputs will be in separate PRs.
fixes https://github.com/pytorch/pytorch/issues/138391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135403
Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
f5b9e725d1
commit
5d67efb809
4 changed files with 139 additions and 18 deletions
|
|
@ -5,6 +5,8 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
|
||||
from onnxscript import BOOL, FLOAT, opset18 as op
|
||||
|
||||
import torch
|
||||
import torch.onnx._flags
|
||||
from torch.onnx._internal.exporter import _testing as onnx_testing
|
||||
|
|
@ -206,5 +208,91 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
|||
self.assert_export(Model(), (input))
|
||||
|
||||
|
||||
class TestCustomTranslationTable(common_utils.TestCase):
|
||||
def test_custom_translation_table_overrides_ops(self):
|
||||
from onnxscript import opset18 as op
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x + y
|
||||
|
||||
def custom_add(self, other):
|
||||
# Replace add with sub
|
||||
return op.Sub(self, other)
|
||||
|
||||
custom_translation_table = {torch.ops.aten.add.Tensor: custom_add}
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(),
|
||||
(torch.randn(2, 2), torch.randn(2, 2)),
|
||||
custom_translation_table=custom_translation_table,
|
||||
dynamo=True,
|
||||
)
|
||||
all_nodes = [n.op_type for n in onnx_program.model.graph]
|
||||
self.assertIn("Sub", all_nodes)
|
||||
self.assertNotIn("Add", all_nodes)
|
||||
|
||||
def test_custom_translation_table_supports_overloading_ops(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_and.default(x, y)
|
||||
|
||||
def custom_add_bool(self: BOOL, other: BOOL) -> BOOL:
|
||||
# Replace add with sub
|
||||
return op.Sub(self, other)
|
||||
|
||||
def custom_add(self: FLOAT, other: FLOAT) -> FLOAT:
|
||||
# Replace add with mul
|
||||
return op.Mul(self, other)
|
||||
|
||||
custom_translation_table = {
|
||||
torch.ops.aten.logical_and.default: [custom_add, custom_add_bool],
|
||||
}
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(),
|
||||
(torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
|
||||
custom_translation_table=custom_translation_table,
|
||||
dynamo=True,
|
||||
)
|
||||
all_nodes = [n.op_type for n in onnx_program.model.graph]
|
||||
# The dispatcher should pick the correct overload based on the input types
|
||||
self.assertIn("Sub", all_nodes)
|
||||
self.assertNotIn("Add", all_nodes)
|
||||
self.assertNotIn("Mul", all_nodes)
|
||||
|
||||
def test_custom_translation_table_supports_custom_op_as_target(self):
|
||||
# Define the custom op and use it in the model
|
||||
@torch.library.custom_op("custom::add", mutates_args=())
|
||||
def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a + b
|
||||
|
||||
@custom_add.register_fake
|
||||
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(a) + torch.empty_like(b)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return custom_add(x, y)
|
||||
|
||||
def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
|
||||
# Replace add with Sub
|
||||
return op.Sub(self, other)
|
||||
|
||||
custom_translation_table = {
|
||||
torch.ops.custom.add.default: onnx_add,
|
||||
}
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
Model(),
|
||||
(torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)),
|
||||
custom_translation_table=custom_translation_table,
|
||||
dynamo=True,
|
||||
)
|
||||
all_nodes = [n.op_type for n in onnx_program.model.graph]
|
||||
self.assertIn("Sub", all_nodes)
|
||||
self.assertNotIn("Add", all_nodes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
|
|
|||
|
|
@ -148,6 +148,8 @@ def export(
|
|||
# Dynamo only options
|
||||
external_data: bool = True,
|
||||
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
|
||||
custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
|
||||
| None = None,
|
||||
report: bool = False,
|
||||
optimize: bool = False,
|
||||
verify: bool = False,
|
||||
|
|
@ -280,15 +282,24 @@ def export(
|
|||
:func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
|
||||
Only one parameter `dynamic_axes` or `dynamic_shapes` should be set
|
||||
at the same time.
|
||||
report: Whether to generate a markdown report for the export process.
|
||||
optimize: Whether to optimize the exported model.
|
||||
verify: Whether to verify the exported model using ONNX Runtime.
|
||||
profile: Whether to profile the export process.
|
||||
custom_translation_table: A dictionary of custom decompositions for operators in the model.
|
||||
The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``),
|
||||
and the value should be a function that builds that graph using ONNX Script. This option
|
||||
is only valid when dynamo is True.
|
||||
report: Whether to generate a markdown report for the export process. This option
|
||||
is only valid when dynamo is True.
|
||||
optimize: Whether to optimize the exported model. This option
|
||||
is only valid when dynamo is True.
|
||||
verify: Whether to verify the exported model using ONNX Runtime. This option
|
||||
is only valid when dynamo is True.
|
||||
profile: Whether to profile the export process. This option
|
||||
is only valid when dynamo is True.
|
||||
dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
|
||||
This is useful for debugging the exporter.
|
||||
This is useful for debugging the exporter. This option is only valid when dynamo is True.
|
||||
artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
|
||||
exported program.
|
||||
exported program. This option is only valid when dynamo is True.
|
||||
fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.
|
||||
This option is only valid when dynamo is True.
|
||||
|
||||
training: Deprecated option. Instead, set the training mode of the model before exporting.
|
||||
operator_export_type: Deprecated option. Only ONNX is supported.
|
||||
|
|
@ -346,6 +357,7 @@ def export(
|
|||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
custom_translation_table=custom_translation_table,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
external_data=external_data,
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ from __future__ import annotations
|
|||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Mapping, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import _core, _onnx_program
|
||||
from torch.onnx._internal.exporter import _core, _onnx_program, _registration
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -125,6 +125,8 @@ def export_compat(
|
|||
input_names: Sequence[str] | None = None,
|
||||
output_names: Sequence[str] | None = None,
|
||||
opset_version: int | None = None,
|
||||
custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
|
||||
| None = None,
|
||||
dynamic_axes: Mapping[str, Mapping[int, str]]
|
||||
| Mapping[str, Sequence[int]]
|
||||
| None = None,
|
||||
|
|
@ -158,12 +160,22 @@ def export_compat(
|
|||
output_names=set(output_names or ()),
|
||||
)
|
||||
|
||||
registry = _registration.ONNXRegistry.from_torchlib()
|
||||
if custom_translation_table is not None:
|
||||
for torch_op, onnx_ops in custom_translation_table.items():
|
||||
# TODO(justinchuby): Support complex inputs with annotations
|
||||
if not isinstance(onnx_ops, Sequence):
|
||||
onnx_ops = (onnx_ops,)
|
||||
for op in reversed(onnx_ops):
|
||||
# register_op places the op in the front of all onnx variants,
|
||||
# so we reverse the list to maintain the order of the custom ops provided
|
||||
registry.register_op(torch_op, op, is_complex=False)
|
||||
try:
|
||||
onnx_program = _core.export(
|
||||
model,
|
||||
args,
|
||||
kwargs,
|
||||
registry=None,
|
||||
registry=registry,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
|
|
|
|||
|
|
@ -17,19 +17,15 @@ import logging
|
|||
import math
|
||||
import operator
|
||||
import types
|
||||
import typing
|
||||
from typing import Callable, Literal, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch._ops
|
||||
from torch.onnx._internal._lazy_import import onnxscript_apis
|
||||
from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis
|
||||
from torch.onnx._internal.exporter import _schemas
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import onnxscript
|
||||
|
||||
_DEFAULT_OPSET_VERSION = 18
|
||||
|
||||
|
||||
|
|
@ -153,9 +149,6 @@ class ONNXRegistry:
|
|||
try:
|
||||
# NOTE: This is heavily guarded with try-except because we don't want
|
||||
# to fail the entire registry population if one function fails.
|
||||
if qualified_name.startswith("internal::"):
|
||||
# Skip the custom defined internal functions
|
||||
continue
|
||||
target = _get_overload(qualified_name)
|
||||
if target is None:
|
||||
continue
|
||||
|
|
@ -203,7 +196,7 @@ class ONNXRegistry:
|
|||
def register_op(
|
||||
self,
|
||||
target: TorchOp,
|
||||
function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
|
||||
function: Callable,
|
||||
is_complex: bool = False,
|
||||
) -> None:
|
||||
"""Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
|
||||
|
|
@ -213,6 +206,22 @@ class ONNXRegistry:
|
|||
function: The onnx-script function to register.
|
||||
is_complex: Whether the function is a function that handles complex valued inputs.
|
||||
"""
|
||||
if not hasattr(function, "signature"):
|
||||
try:
|
||||
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
|
||||
if isinstance(function, onnxscript.OnnxFunction):
|
||||
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
|
||||
function, function.function_ir.domain, function.name
|
||||
)
|
||||
else:
|
||||
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
|
||||
function, "__custom", function.__name__
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to infer the signature for function '%s'", function
|
||||
)
|
||||
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=function,
|
||||
fx_target=target,
|
||||
|
|
|
|||
Loading…
Reference in a new issue