Update torch.onnx.OnnxRegistry usage in DORT tests (#17009)

Update the usage of torch.onnx.OnnxRegistry, as it's officially
published in PyTorch: https://github.com/pytorch/pytorch/pull/106140.

---------

Co-authored-by: Wei-Sheng Chin <wechi@microsoft.com>
This commit is contained in:
Ti-Tai Wang 2023-08-07 10:15:51 -07:00 committed by GitHub
parent 4e6ea730d6
commit 8a335b8347
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -11,7 +11,6 @@ import torch._dynamo
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch.library import Library
from torch.onnx._internal.exporter import ExportOptions
import onnxruntime
from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
@ -99,25 +98,25 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
"""
torch._dynamo.reset()
# Create executor of ONNX model.
# We will register a custom exporter for aten.mul.Tensor
# in the following step.
ort_backend = OrtBackend(
ep="CPUExecutionProvider",
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
onnx_exporter_options=ExportOptions(dynamic_shapes=True),
)
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
# exporter.
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see
# the sub-graph representing "aten::mul.Tensor".
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
function=custom_exporter_for_aten_add_Tensor,
namespace="aten",
op_name="mul",
overload="Tensor",
)
# In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`.
ort_backend = OrtBackend(
ep="CPUExecutionProvider",
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
)
# Wrap ORT executor as a Dynamo backend.
aot_ort = aot_autograd(
fw_compiler=ort_backend,
@ -159,21 +158,26 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")
# Ask exporter to map "torch.ops.foo.bar" to
# custom_exporter_for_foo_bar_default.
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
function=custom_exporter_for_aten_add_Tensor,
namespace="aten",
op_name="mul",
overload="Tensor",
)
# Create executor of ONNX model.
ort_backend = OrtBackend(
ep="CPUExecutionProvider", session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
ep="CPUExecutionProvider",
session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry),
)
# Allow torch.ops.foo.bar.default to be sent to DORT.
# _support_dict tells Dynamo which ops to sent to DORT.
ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
# Ask exporter to map "torch.ops.foo.bar" to
# custom_exporter_for_foo_bar_default.
# TODO(wechi): Redesign API to expose this better.
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
function=custom_exporter_for_foo_bar_default,
namespace="foo",
op_name="bar",
)
# Wrap ORT executor as a Dynamo backend.
aot_ort = aot_autograd(
fw_compiler=ort_backend,