mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Update torch.onnx.OnnxRegistry usage in DORT tests (#17009)
Update the usage of torch.onnx.OnnxRegistry, as it's officially published in PyTorch: https://github.com/pytorch/pytorch/pull/106140. --------- Co-authored-by: Wei-Sheng Chin <wechi@microsoft.com>
This commit is contained in:
parent
4e6ea730d6
commit
8a335b8347
1 changed files with 23 additions and 19 deletions
|
|
@ -11,7 +11,6 @@ import torch._dynamo
|
|||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch.library import Library
|
||||
from torch.onnx._internal.exporter import ExportOptions
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
|
||||
|
|
@ -99,25 +98,25 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
|
|||
"""
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Create executor of ONNX model.
|
||||
# We will register a custom exporter for aten.mul.Tensor
|
||||
# in the following step.
|
||||
ort_backend = OrtBackend(
|
||||
ep="CPUExecutionProvider",
|
||||
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
|
||||
onnx_exporter_options=ExportOptions(dynamic_shapes=True),
|
||||
)
|
||||
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
|
||||
# exporter.
|
||||
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see
|
||||
# the sub-graph representing "aten::mul.Tensor".
|
||||
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
onnx_registry = torch.onnx.OnnxRegistry()
|
||||
onnx_registry.register_op(
|
||||
function=custom_exporter_for_aten_add_Tensor,
|
||||
namespace="aten",
|
||||
op_name="mul",
|
||||
overload="Tensor",
|
||||
)
|
||||
|
||||
# In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`.
|
||||
ort_backend = OrtBackend(
|
||||
ep="CPUExecutionProvider",
|
||||
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
|
||||
onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
|
||||
)
|
||||
|
||||
# Wrap ORT executor as a Dynamo backend.
|
||||
aot_ort = aot_autograd(
|
||||
fw_compiler=ort_backend,
|
||||
|
|
@ -159,21 +158,26 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
|
|||
|
||||
foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")
|
||||
|
||||
# Ask exporter to map "torch.ops.foo.bar" to
|
||||
# custom_exporter_for_foo_bar_default.
|
||||
onnx_registry = torch.onnx.OnnxRegistry()
|
||||
onnx_registry.register_op(
|
||||
function=custom_exporter_for_aten_add_Tensor,
|
||||
namespace="aten",
|
||||
op_name="mul",
|
||||
overload="Tensor",
|
||||
)
|
||||
|
||||
# Create executor of ONNX model.
|
||||
ort_backend = OrtBackend(
|
||||
ep="CPUExecutionProvider", session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
|
||||
ep="CPUExecutionProvider",
|
||||
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
|
||||
onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry),
|
||||
)
|
||||
# Allow torch.ops.foo.bar.default to be sent to DORT.
|
||||
# _support_dict tells Dynamo which ops to sent to DORT.
|
||||
ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
|
||||
# Ask exporter to map "torch.ops.foo.bar" to
|
||||
# custom_exporter_for_foo_bar_default.
|
||||
# TODO(wechi): Redesign API to expose this better.
|
||||
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
function=custom_exporter_for_foo_bar_default,
|
||||
namespace="foo",
|
||||
op_name="bar",
|
||||
)
|
||||
|
||||
# Wrap ORT executor as a Dynamo backend.
|
||||
aot_ort = aot_autograd(
|
||||
fw_compiler=ort_backend,
|
||||
|
|
|
|||
Loading…
Reference in a new issue