From b71ebf91a560127fa190f2ececcba7dfab5d1481 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 18 Jul 2023 09:06:58 -0700 Subject: [PATCH] [DORT] Reduce global configs to make enabling dynamic shape easier (#16720) There are several global configs used by DORT. ```py DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions( torch.onnx._internal.exporter.ExportOptions() ) # TODO(wechi): This line must generate result identical to the call of # _create_onnx_supports_op_overload_table(...) inside # create_onnx_friendly_decomposition_table(...) in # torch/onnx/_internal/fx/decomposition_table.py. _SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry ) # type: ignore _EXTRA_SUPPORT_DICT: Dict[str, Any] = { "getattr": None, "_operator.getitem": None, } DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table ``` We can see all but `_EXTRA_SUPPORT_DICT` are extracted from deduced from ONNX exporter's options. As there are many ways to configure ONNX exporter's options, we decided to move these variables to `OrtBackend`'s `__init__` so that the construction of `OrtBackend` becomes more flexible (especially for enabling dynamic shape or not). --- .../training/torchdynamo/ort_backend.py | 87 +++++++++++-------- .../training/torchdynamo/register_backend.py | 4 +- .../orttraining_test_dort_custom_ops.py | 54 +++++------- 3 files changed, 77 insertions(+), 68 deletions(-) diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index 7493127924..2454079dc9 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -5,7 +5,7 @@ import dataclasses import logging -from typing import Any, Dict, Mapping, Tuple, Union +from typing import Any, Dict, Mapping, Optional, Set, Tuple, Union import numpy as np import onnx @@ -14,7 +14,6 @@ import torch._C import torch._ops import torch._prims.executor import torch.fx -import torch.jit import torch.onnx # TODO(wschin,justinchuby): Since the internal APIs are not stable, please @@ -24,7 +23,6 @@ import torch.onnx._internal.diagnostics import torch.onnx._internal.exporter import torch.onnx._internal.fx.decomposition_table import torch.onnx._internal.fx.passes -import torch.onnx._onnx_supported_ops from torch._subclasses.fake_tensor import FakeTensor from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -35,30 +33,6 @@ from torch.utils import _pytree import onnxruntime # type: ignore from onnxruntime.capi import _pybind_state as ORTC -# DEFAULT_ONNX_EXPORTER_OPTIONS contains shared information between exporter and DORT. -# For example, they should use the same decomposition table to maintain the same set -# operators when -# 1. capturing FX graph in torch.compile -# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model. -DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions( - torch.onnx._internal.exporter.ExportOptions() -) - -# TODO(wechi): This line must generate result identical to the call of -# _create_onnx_supports_op_overload_table(...) inside -# create_onnx_friendly_decomposition_table(...) in -# torch/onnx/_internal/fx/decomposition_table.py. -_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( - DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry -) # type: ignore - -_EXTRA_SUPPORT_DICT: Dict[str, Any] = { - "getattr": None, - "_operator.getitem": None, -} - -DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table - _NP_DTYPE = { torch.float16: np.float16, torch.float32: np.float32, @@ -115,8 +89,13 @@ class OrtOperatorSupport(OperatorSupport): is used by OperatorSupport.is_node_supported. """ - def __init__(self): - super().__init__(_EXTRA_SUPPORT_DICT) + def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]): + # Use extra_support_dict[op_name] = None to indicate + # we support op_name with all input types. Otherwise, + # see support_dict (type: SupportDict) in operator_support.py + # for specifying supported types. + super().__init__(extra_support_dict) + self._support_dict = support_dict def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool: # OperatorSupport.is_node_supported returns True for non-callable nodes. @@ -125,7 +104,7 @@ class OrtOperatorSupport(OperatorSupport): if node.op not in CALLABLE_NODE_OPS: return False # This is the and the only place to decide if aten op is supported. - if node.op == "call_function" and node.target in _SUPPORT_DICT: + if node.op == "call_function" and node.target in self._support_dict: logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target)) return True logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target)) @@ -390,8 +369,44 @@ class OrtBackend: 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. """ - def __init__(self, ep: str = "CPUExecutionProvider", preallocate_output: bool = False, session_options=None): - self._supported_ops = OrtOperatorSupport() + def __init__( + self, + ep: str = "CPUExecutionProvider", + preallocate_output: bool = False, + session_options=None, + onnx_exporter_options: Optional["torch.onnx._internal.exporter.ExportOptions"] = None, + ): + # onnx_exporter_options contains information shared between exporter and DORT. + # For example, they should use the same decomposition table when + # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) + # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model + # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). + if onnx_exporter_options is None: + onnx_exporter_options = torch.onnx._internal.exporter.ExportOptions() + # Convert user-facing option to internal option used by ONNX exporter + # to access required information. + # Some useful fields: + # - Decomposition table for decomposing FX operators in exporter is + # self.resolved_onnx_exporter_options.decomposition_table. + # - self.resolved_onnx_exporter_options.onnx_registry records what + # aten/prim ops are supported by exporter and their exporters (type: callable). + self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options) + + # TODO(wechi): This line must generate result identical to the call of + # _create_onnx_supports_op_overload_table(...) inside + # create_onnx_friendly_decomposition_table(...) in + # torch/onnx/_internal/fx/decomposition_table.py. + support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( + # This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry. + self.resolved_onnx_exporter_options.onnx_registry + ) # type: ignore + + extra_support_dict: Dict[str, Any] = { + "getattr": None, + "_operator.getitem": None, + } + + self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) # TODO: this is a naive implementation of cache without proper guard self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} # TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs @@ -447,18 +462,18 @@ class OrtBackend: # Create the object to iterate through the nodes in graph one-by-one # and calls the corresponding ONNX exporter for each node. fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter( - diagnostic_context=DEFAULT_ONNX_EXPORTER_OPTIONS.diagnostic_context + diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context ) # Start the per-node exporting process. It's conceptually a for loop # scanning through the nodes in the graph. exported = fx_interpreter.run( fx_graph_module=graph_module, - onnxfunction_dispatcher=DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher, - op_level_debug=DEFAULT_ONNX_EXPORTER_OPTIONS.op_level_debug, + onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher, + op_level_debug=self.resolved_onnx_exporter_options.op_level_debug, ) # Convert the exported result to ONNX ModelProto. onnx_proto = exported.to_model_proto( - opset_version=DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version + opset_version=self.resolved_onnx_exporter_options.opset_version ).SerializeToString() # Initialize a ORT session to execute this ONNX model. diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py index 1aa2692e70..9030c6f8fb 100644 --- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/register_backend.py @@ -6,7 +6,7 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd -from .ort_backend import DORT_DECOMPOSITION_TABLE, OrtBackend +from .ort_backend import OrtBackend # This should be the underlying compiler for ALL graphs if # the user uses ORT to accelerate PyTorch via Dynamo. @@ -32,7 +32,7 @@ DEFAULT_BACKEND = OrtBackend() aot_ort = aot_autograd( fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition, - decompositions=DORT_DECOMPOSITION_TABLE, + decompositions=DEFAULT_BACKEND.resolved_onnx_exporter_options.decomposition_table, ) # Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called). diff --git a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py index 66c1bc672d..9c18b0347f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py @@ -13,12 +13,7 @@ from torch._dynamo.backends.common import aot_autograd from torch.library import Library import onnxruntime -from onnxruntime.training.torchdynamo.ort_backend import ( - _SUPPORT_DICT, - DEFAULT_ONNX_EXPORTER_OPTIONS, - DORT_DECOMPOSITION_TABLE, - OrtBackend, -) +from onnxruntime.training.torchdynamo.ort_backend import OrtBackend # Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne # in ONNX model executed by DORT. @@ -35,18 +30,6 @@ def custom_exporter_for_aten_add_Tensor(x, y): return custom_opset.CustomOpOne(x, y) -# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s -# exporter. -# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to investigate -# function representing "aten::mul.Tensor". -DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op( - function=custom_exporter_for_aten_add_Tensor, - namespace="aten", - op_name="mul", - overload="Tensor", -) - - # Exporter for torch.ops.foo.bar.default. @onnxscript.script(custom_opset) def custom_exporter_for_foo_bar_default(x): @@ -55,15 +38,6 @@ def custom_exporter_for_foo_bar_default(x): return custom_opset.CustomOpOne(x, x) -# Ask exporter to map "torch.ops.foo.bar" to -# custom_exporter_for_foo_bar_default. -DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register_custom_op( - function=custom_exporter_for_foo_bar_default, - namespace="foo", - op_name="bar", -) - - class TestTorchDynamoOrtCustomOp(unittest.TestCase): """Containers of custom op lib test for TorchDynamo ORT (DORT) backend.""" @@ -122,10 +96,21 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase): session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options() ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options) + # Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s + # exporter. + # Use custom_exporter_for_aten_add_Tensor.to_function_proto() to see + # the sub-graph representing "aten::mul.Tensor". + ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op( + function=custom_exporter_for_aten_add_Tensor, + namespace="aten", + op_name="mul", + overload="Tensor", + ) + aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=DORT_DECOMPOSITION_TABLE, + decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, ) def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor): @@ -155,14 +140,23 @@ class TestTorchDynamoOrtCustomOp(unittest.TestCase): foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd") # TODO(wechi): Redesign API to expose this better. - _SUPPORT_DICT.add(torch.ops.foo.bar.default) session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options() ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options) + # Allow torch.ops.foo.bar.default to be sent to DORT. + # _support_dict tells Dynamo which ops to sent to DORT. + ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default) + # Ask exporter to map "torch.ops.foo.bar" to + # custom_exporter_for_foo_bar_default. + ort_backend.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.register_custom_op( + function=custom_exporter_for_foo_bar_default, + namespace="foo", + op_name="bar", + ) aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=DORT_DECOMPOSITION_TABLE, + decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, ) def one_foo(tensor_x: torch.Tensor):