[DORT] Reduce global configs to make enabling dynamic shape easier (#16720)

There are several global configs used by DORT.
```py
DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions(
    torch.onnx._internal.exporter.ExportOptions()
)

# TODO(wechi): This line must generate result identical to the call of
# _create_onnx_supports_op_overload_table(...) inside
# create_onnx_friendly_decomposition_table(...) in
# torch/onnx/_internal/fx/decomposition_table.py.
_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
    DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry
)  # type: ignore

_EXTRA_SUPPORT_DICT: Dict[str, Any] = {
    "getattr": None,
    "_operator.getitem": None,
}

DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table
```

We can see all but `_EXTRA_SUPPORT_DICT` are extracted from deduced from
ONNX exporter's options. As there are many ways to configure ONNX
exporter's options, we decided to move these variables to `OrtBackend`'s
`__init__` so that the construction of `OrtBackend` becomes more
flexible (especially for enabling dynamic shape or not).
This commit is contained in:
Wei-Sheng Chin 2023-07-18 09:06:58 -07:00 committed by GitHub
parent 9b549c646c
commit b71ebf91a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 68 deletions

View file

@ -5,7 +5,7 @@
import dataclasses
import logging
from typing import Any, Dict, Mapping, Tuple, Union
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Union
import numpy as np
import onnx
@ -14,7 +14,6 @@ import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
import torch.jit
import torch.onnx
# TODO(wschin,justinchuby): Since the internal APIs are not stable, please
@ -24,7 +23,6 @@ import torch.onnx._internal.diagnostics
import torch.onnx._internal.exporter
import torch.onnx._internal.fx.decomposition_table
import torch.onnx._internal.fx.passes
import torch.onnx._onnx_supported_ops
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
@ -35,30 +33,6 @@ from torch.utils import _pytree
import onnxruntime # type: ignore
from onnxruntime.capi import _pybind_state as ORTC
# DEFAULT_ONNX_EXPORTER_OPTIONS contains shared information between exporter and DORT.
# For example, they should use the same decomposition table to maintain the same set
# operators when
# 1. capturing FX graph in torch.compile
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model.
DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions(
torch.onnx._internal.exporter.ExportOptions()
)
# TODO(wechi): This line must generate result identical to the call of
# _create_onnx_supports_op_overload_table(...) inside
# create_onnx_friendly_decomposition_table(...) in
# torch/onnx/_internal/fx/decomposition_table.py.
_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry
) # type: ignore
_EXTRA_SUPPORT_DICT: Dict[str, Any] = {
"getattr": None,
"_operator.getitem": None,
}
DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table
_NP_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
@ -115,8 +89,13 @@ class OrtOperatorSupport(OperatorSupport):
is used by OperatorSupport.is_node_supported.
"""
def __init__(self):
super().__init__(_EXTRA_SUPPORT_DICT)
def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
# Use extra_support_dict[op_name] = None to indicate
# we support op_name with all input types. Otherwise,
# see support_dict (type: SupportDict) in operator_support.py
# for specifying supported types.
super().__init__(extra_support_dict)
self._support_dict = support_dict
def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool:
# OperatorSupport.is_node_supported returns True for non-callable nodes.
@ -125,7 +104,7 @@ class OrtOperatorSupport(OperatorSupport):
if node.op not in CALLABLE_NODE_OPS:
return False
# This is the and the only place to decide if aten op is supported.
if node.op == "call_function" and node.target in _SUPPORT_DICT:
if node.op == "call_function" and node.target in self._support_dict:
logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target))
return True
logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target))
@ -390,8 +369,44 @@ class OrtBackend:
3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
"""
def __init__(self, ep: str = "CPUExecutionProvider", preallocate_output: bool = False, session_options=None):
self._supported_ops = OrtOperatorSupport()
def __init__(
self,
ep: str = "CPUExecutionProvider",
preallocate_output: bool = False,
session_options=None,
onnx_exporter_options: Optional["torch.onnx._internal.exporter.ExportOptions"] = None,
):
# onnx_exporter_options contains information shared between exporter and DORT.
# For example, they should use the same decomposition table when
# 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model
# (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
if onnx_exporter_options is None:
onnx_exporter_options = torch.onnx._internal.exporter.ExportOptions()
# Convert user-facing option to internal option used by ONNX exporter
# to access required information.
# Some useful fields:
# - Decomposition table for decomposing FX operators in exporter is
# self.resolved_onnx_exporter_options.decomposition_table.
# - self.resolved_onnx_exporter_options.onnx_registry records what
# aten/prim ops are supported by exporter and their exporters (type: callable).
self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options)
# TODO(wechi): This line must generate result identical to the call of
# _create_onnx_supports_op_overload_table(...) inside
# create_onnx_friendly_decomposition_table(...) in
# torch/onnx/_internal/fx/decomposition_table.py.
support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
# This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.
self.resolved_onnx_exporter_options.onnx_registry
) # type: ignore
extra_support_dict: Dict[str, Any] = {
"getattr": None,
"_operator.getitem": None,
}
self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
# TODO: this is a naive implementation of cache without proper guard
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
# TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
@ -447,18 +462,18 @@ class OrtBackend:
# Create the object to iterate through the nodes in graph one-by-one
# and calls the corresponding ONNX exporter for each node.
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
diagnostic_context=DEFAULT_ONNX_EXPORTER_OPTIONS.diagnostic_context
diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context
)
# Start the per-node exporting process. It's conceptually a for loop
# scanning through the nodes in the graph.
exported = fx_interpreter.run(
fx_graph_module=graph_module,
onnxfunction_dispatcher=DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher,
op_level_debug=DEFAULT_ONNX_EXPORTER_OPTIONS.op_level_debug,
onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher,
op_level_debug=self.resolved_onnx_exporter_options.op_level_debug,
)
# Convert the exported result to ONNX ModelProto.
onnx_proto = exported.to_model_proto(
opset_version=DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version
opset_version=self.resolved_onnx_exporter_options.opset_version
).SerializeToString()
# Initialize a ORT session to execute this ONNX model.

View file

@ -6,7 +6,7 @@
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from .ort_backend import DORT_DECOMPOSITION_TABLE, OrtBackend
from .ort_backend import OrtBackend
# This should be the underlying compiler for ALL graphs if
# the user uses ORT to accelerate PyTorch via Dynamo.
@ -32,7 +32,7 @@ DEFAULT_BACKEND = OrtBackend()
aot_ort = aot_autograd(
fw_compiler=DEFAULT_BACKEND,
partition_fn=min_cut_rematerialization_partition,
decompositions=DORT_DECOMPOSITION_TABLE,
decompositions=DEFAULT_BACKEND.resolved_onnx_exporter_options.decomposition_table,
)
# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called).

View file

@ -13,12 +13,7 @@ from torch._dynamo.backends.common import aot_autograd
from torch.library import Library
import onnxruntime
from onnxruntime.training.torchdynamo.ort_backend import (
_SUPPORT_DICT,
DEFAULT_ONNX_EXPORTER_OPTIONS,
DORT_DECOMPOSITION_TABLE,
OrtBackend,
)
from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
# Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne
# in ONNX model executed by DORT.
@ -35,18 +30,6 @@ def custom_exporter_for_aten_add_Tensor(x, y):
return custom_opset.CustomOpOne(x, y)
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
# exporter.
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to investigate
# function representing "aten::mul.Tensor".
DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op(
function=custom_exporter_for_aten_add_Tensor,
namespace="aten",
op_name="mul",
overload="Tensor",
)
# Exporter for torch.ops.foo.bar.default.
@onnxscript.script(custom_opset)
def custom_exporter_for_foo_bar_default(x):
@ -55,15 +38,6 @@ def custom_exporter_for_foo_bar_default(x):
return custom_opset.CustomOpOne(x, x)
# Ask exporter to map "torch.ops.foo.bar" to
# custom_exporter_for_foo_bar_default.
DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op(
function=custom_exporter_for_foo_bar_default,
namespace="foo",
op_name="bar",
)
class TestTorchDynamoOrtCustomOp(unittest.TestCase):
"""Containers of custom op lib test for TorchDynamo ORT (DORT) backend."""
@ -122,10 +96,21 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options)
# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s
# exporter.
# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see
# the sub-graph representing "aten::mul.Tensor".
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
function=custom_exporter_for_aten_add_Tensor,
namespace="aten",
op_name="mul",
overload="Tensor",
)
aot_ort = aot_autograd(
fw_compiler=ort_backend,
partition_fn=min_cut_rematerialization_partition,
decompositions=DORT_DECOMPOSITION_TABLE,
decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
)
def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor):
@ -155,14 +140,23 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase):
foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd")
# TODO(wechi): Redesign API to expose this better.
_SUPPORT_DICT.add(torch.ops.foo.bar.default)
session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options()
ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options)
# Allow torch.ops.foo.bar.default to be sent to DORT.
# _support_dict tells Dynamo which ops to sent to DORT.
ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
# Ask exporter to map "torch.ops.foo.bar" to
# custom_exporter_for_foo_bar_default.
ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op(
function=custom_exporter_for_foo_bar_default,
namespace="foo",
op_name="bar",
)
aot_ort = aot_autograd(
fw_compiler=ort_backend,
partition_fn=min_cut_rematerialization_partition,
decompositions=DORT_DECOMPOSITION_TABLE,
decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
)
def one_foo(tensor_x: torch.Tensor):