mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
[DORT] Reduce global configs to make enabling dynamic shape easier (#16720)
There are several global configs used by DORT.
```py
DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions(
torch.onnx._internal.exporter.ExportOptions()
)
# TODO(wechi): This line must generate result identical to the call of
# _create_onnx_supports_op_overload_table(...) inside
# create_onnx_friendly_decomposition_table(...) in
# torch/onnx/_internal/fx/decomposition_table.py.
_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry
) # type: ignore
_EXTRA_SUPPORT_DICT: Dict[str, Any] = {
"getattr": None,
"_operator.getitem": None,
}
DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table
```
We can see all but `_EXTRA_SUPPORT_DICT` are extracted from deduced from
ONNX exporter's options. As there are many ways to configure ONNX
exporter's options, we decided to move these variables to `OrtBackend`'s
`__init__` so that the construction of `OrtBackend` becomes more
flexible (especially for enabling dynamic shape or not).
This commit is contained in:
parent
9b549c646c
commit
b71ebf91a5
3 changed files with 77 additions and 68 deletions
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, Dict, Mapping, Tuple, Union
|
||||
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
|
@ -14,7 +14,6 @@ import torch._C
|
|||
import torch._ops
|
||||
import torch._prims.executor
|
||||
import torch.fx
|
||||
import torch.jit
|
||||
import torch.onnx
|
||||
|
||||
# TODO(wschin,justinchuby): Since the internal APIs are not stable, please
|
||||
|
|
@ -24,7 +23,6 @@ import torch.onnx._internal.diagnostics
|
|||
import torch.onnx._internal.exporter
|
||||
import torch.onnx._internal.fx.decomposition_table
|
||||
import torch.onnx._internal.fx.passes
|
||||
import torch.onnx._onnx_supported_ops
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
||||
|
|
@ -35,30 +33,6 @@ from torch.utils import _pytree
|
|||
import onnxruntime # type: ignore
|
||||
from onnxruntime.capi import _pybind_state as ORTC
|
||||
|
||||
# DEFAULT_ONNX_EXPORTER_OPTIONS contains shared information between exporter and DORT.
|
||||
# For example, they should use the same decomposition table to maintain the same set
|
||||
# operators when
|
||||
# 1. capturing FX graph in torch.compile
|
||||
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model.
|
||||
DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions(
|
||||
torch.onnx._internal.exporter.ExportOptions()
|
||||
)
|
||||
|
||||
# TODO(wechi): This line must generate result identical to the call of
|
||||
# _create_onnx_supports_op_overload_table(...) inside
|
||||
# create_onnx_friendly_decomposition_table(...) in
|
||||
# torch/onnx/_internal/fx/decomposition_table.py.
|
||||
_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
|
||||
DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry
|
||||
) # type: ignore
|
||||
|
||||
_EXTRA_SUPPORT_DICT: Dict[str, Any] = {
|
||||
"getattr": None,
|
||||
"_operator.getitem": None,
|
||||
}
|
||||
|
||||
DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table
|
||||
|
||||
_NP_DTYPE = {
|
||||
torch.float16: np.float16,
|
||||
torch.float32: np.float32,
|
||||
|
|
@ -115,8 +89,13 @@ class OrtOperatorSupport(OperatorSupport):
|
|||
is used by OperatorSupport.is_node_supported.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(_EXTRA_SUPPORT_DICT)
|
||||
def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
|
||||
# Use extra_support_dict[op_name] = None to indicate
|
||||
# we support op_name with all input types. Otherwise,
|
||||
# see support_dict (type: SupportDict) in operator_support.py
|
||||
# for specifying supported types.
|
||||
super().__init__(extra_support_dict)
|
||||
self._support_dict = support_dict
|
||||
|
||||
def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool:
|
||||
# OperatorSupport.is_node_supported returns True for non-callable nodes.
|
||||
|
|
@ -125,7 +104,7 @@ class OrtOperatorSupport(OperatorSupport):
|
|||
if node.op not in CALLABLE_NODE_OPS:
|
||||
return False
|
||||
# This is the and the only place to decide if aten op is supported.
|
||||
if node.op == "call_function" and node.target in _SUPPORT_DICT:
|
||||
if node.op == "call_function" and node.target in self._support_dict:
|
||||
logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target))
|
||||
return True
|
||||
logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target))
|
||||
|
|
@ -390,8 +369,44 @@ class OrtBackend:
|
|||
3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
|
||||
"""
|
||||
|
||||
def __init__(self, ep: str = "CPUExecutionProvider", preallocate_output: bool = False, session_options=None):
|
||||
self._supported_ops = OrtOperatorSupport()
|
||||
def __init__(
|
||||
self,
|
||||
ep: str = "CPUExecutionProvider",
|
||||
preallocate_output: bool = False,
|
||||
session_options=None,
|
||||
onnx_exporter_options: Optional["torch.onnx._internal.exporter.ExportOptions"] = None,
|
||||
):
|
||||
# onnx_exporter_options contains information shared between exporter and DORT.
|
||||
# For example, they should use the same decomposition table when
|
||||
# 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
|
||||
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model
|
||||
# (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
|
||||
if onnx_exporter_options is None:
|
||||
onnx_exporter_options = torch.onnx._internal.exporter.ExportOptions()
|
||||
# Convert user-facing option to internal option used by ONNX exporter
|
||||
# to access required information.
|
||||
# Some useful fields:
|
||||
# - Decomposition table for decomposing FX operators in exporter is
|
||||
# self.resolved_onnx_exporter_options.decomposition_table.
|
||||
# - self.resolved_onnx_exporter_options.onnx_registry records what
|
||||
# aten/prim ops are supported by exporter and their exporters (type: callable).
|
||||
self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options)
|
||||
|
||||
# TODO(wechi): This line must generate result identical to the call of
|
||||
# _create_onnx_supports_op_overload_table(...) inside
|
||||
# create_onnx_friendly_decomposition_table(...) in
|
||||
# torch/onnx/_internal/fx/decomposition_table.py.
|
||||
support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
|
||||
# This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.
|
||||
self.resolved_onnx_exporter_options.onnx_registry
|
||||
) # type: ignore
|
||||
|
||||
extra_support_dict: Dict[str, Any] = {
|
||||
"getattr": None,
|
||||
"_operator.getitem": None,
|
||||
}
|
||||
|
||||
self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
|
||||
# TODO: this is a naive implementation of cache without proper guard
|
||||
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
|
||||
# TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
|
||||
|
|
@ -447,18 +462,18 @@ class OrtBackend:
|
|||
# Create the object to iterate through the nodes in graph one-by-one
|
||||
# and calls the corresponding ONNX exporter for each node.
|
||||
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
|
||||
diagnostic_context=DEFAULT_ONNX_EXPORTER_OPTIONS.diagnostic_context
|
||||
diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context
|
||||
)
|
||||
# Start the per-node exporting process. It's conceptually a for loop
|
||||
# scanning through the nodes in the graph.
|
||||
exported = fx_interpreter.run(
|
||||
fx_graph_module=graph_module,
|
||||
onnxfunction_dispatcher=DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher,
|
||||
op_level_debug=DEFAULT_ONNX_EXPORTER_OPTIONS.op_level_debug,
|
||||
onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher,
|
||||
op_level_debug=self.resolved_onnx_exporter_options.op_level_debug,
|
||||
)
|
||||
# Convert the exported result to ONNX ModelProto.
|
||||
onnx_proto = exported.to_model_proto(
|
||||
opset_version=DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version
|
||||
opset_version=self.resolved_onnx_exporter_options.opset_version
|
||||
).SerializeToString()
|
||||
|
||||
# Initialize a ORT session to execute this ONNX model.
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
|
||||
from .ort_backend import DORT_DECOMPOSITION_TABLE, OrtBackend
|
||||
from .ort_backend import OrtBackend
|
||||
|
||||
# This should be the underlying compiler for ALL graphs if
|
||||
# the user uses ORT to accelerate PyTorch via Dynamo.
|
||||
|
|
@ -32,7 +32,7 @@ DEFAULT_BACKEND = OrtBackend()
|
|||
aot_ort = aot_autograd(
|
||||
fw_compiler=DEFAULT_BACKEND,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
decompositions=DORT_DECOMPOSITION_TABLE,
|
||||
decompositions=DEFAULT_BACKEND.resolved_onnx_exporter_options.decomposition_table,
|
||||
)
|
||||
|
||||
# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called).
|
||||
|
|
|
|||
|
|
@ -13,12 +13,7 @@ from torch._dynamo.backends.common import aot_autograd
|
|||
from torch.library import Library
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.torchdynamo.ort_backend import (
|
||||
_SUPPORT_DICT,
|
||||
DEFAULT_ONNX_EXPORTER_OPTIONS,
|
||||
DORT_DECOMPOSITION_TABLE,
|
||||
OrtBackend,
|
||||
)
|
||||
from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
|
||||
|
||||
# Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne
|
||||
# in ONNX model executed by DORT.
|
||||
|
|
@ -35,18 +30,6 @@ def custom_exporter_for_aten_add_Tensor(x, y):
|
|||
return custom_opset.CustomOpOne(x, y)
|
||||
|
||||
|
||||
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
|
||||
# exporter.
|
||||
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to investigate
|
||||
# function representing "aten::mul.Tensor".
|
||||
DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
function=custom_exporter_for_aten_add_Tensor,
|
||||
namespace="aten",
|
||||
op_name="mul",
|
||||
overload="Tensor",
|
||||
)
|
||||
|
||||
|
||||
# Exporter for torch.ops.foo.bar.default.
|
||||
@onnxscript.script(custom_opset)
|
||||
def custom_exporter_for_foo_bar_default(x):
|
||||
|
|
@ -55,15 +38,6 @@ def custom_exporter_for_foo_bar_default(x):
|
|||
return custom_opset.CustomOpOne(x, x)
|
||||
|
||||
|
||||
# Ask exporter to map "torch.ops.foo.bar" to
|
||||
# custom_exporter_for_foo_bar_default.
|
||||
DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
function=custom_exporter_for_foo_bar_default,
|
||||
namespace="foo",
|
||||
op_name="bar",
|
||||
)
|
||||
|
||||
|
||||
class TestTorchDynamoOrtCustomOp(unittest.TestCase):
|
||||
"""Containers of custom op lib test for TorchDynamo ORT (DORT) backend."""
|
||||
|
||||
|
|
@ -122,10 +96,21 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
|
|||
session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
|
||||
|
||||
ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options)
|
||||
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
|
||||
# exporter.
|
||||
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see
|
||||
# the sub-graph representing "aten::mul.Tensor".
|
||||
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
function=custom_exporter_for_aten_add_Tensor,
|
||||
namespace="aten",
|
||||
op_name="mul",
|
||||
overload="Tensor",
|
||||
)
|
||||
|
||||
aot_ort = aot_autograd(
|
||||
fw_compiler=ort_backend,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
decompositions=DORT_DECOMPOSITION_TABLE,
|
||||
decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
|
||||
)
|
||||
|
||||
def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor):
|
||||
|
|
@ -155,14 +140,23 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
|
|||
foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")
|
||||
|
||||
# TODO(wechi): Redesign API to expose this better.
|
||||
_SUPPORT_DICT.add(torch.ops.foo.bar.default)
|
||||
|
||||
session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
|
||||
ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options)
|
||||
# Allow torch.ops.foo.bar.default to be sent to DORT.
|
||||
# _support_dict tells Dynamo which ops to sent to DORT.
|
||||
ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
|
||||
# Ask exporter to map "torch.ops.foo.bar" to
|
||||
# custom_exporter_for_foo_bar_default.
|
||||
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
|
||||
function=custom_exporter_for_foo_bar_default,
|
||||
namespace="foo",
|
||||
op_name="bar",
|
||||
)
|
||||
aot_ort = aot_autograd(
|
||||
fw_compiler=ort_backend,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
decompositions=DORT_DECOMPOSITION_TABLE,
|
||||
decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
|
||||
)
|
||||
|
||||
def one_foo(tensor_x: torch.Tensor):
|
||||
|
|
|
|||
Loading…
Reference in a new issue