mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
# Motivation Currently, ORT minimal builds use kernel def hashes to map from nodes to kernels to execute when loading the model. As the kernel def hashes must be known ahead of time, this works for statically registered kernels. This works well for the CPU EP. For this approach to work, the kernel def hashes must also be known at ORT format model conversion time, which means the EP with statically registered kernels must also be enabled then. This is not an issue for the always-available CPU EP. However, we do not want to require that any EP which statically registers kernels is always available too. Consequently, we explore another approach to match nodes to kernels that does not rely on kernel def hashes. An added benefit of this is the possibility of moving away from kernel def hashes completely, which would eliminate the maintenance burden of keeping the hashes stable. # Approach In a full build, ORT uses some information from the ONNX op schema to match a node to a kernel. We want to avoid including the ONNX op schema in a minimal build to reduce binary size. Essentially, we take the necessary information from the ONNX op schema and make it available in a minimal build. We decouple the ONNX op schema from the kernel matching logic. The kernel matching logic instead relies on per-op information which can either be obtained from the ONNX op schema or another source. This per-op information must be available in a minimal build when there are no ONNX op schemas. We put it in the ORT format model. Existing uses of kernel def hashes to look up kernels are replaced with the updated kernel matching logic. We no longer store kernel def hashes in the ORT format model’s session state and runtime optimization representations. We no longer keep the logic to generate and ensure stability of kernel def hashes.
359 lines
14 KiB
Python
Executable file
359 lines
14 KiB
Python
Executable file
# !/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
import io
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import typing
|
|
from pathlib import Path
|
|
|
|
import op_registration_utils
|
|
from logger import get_logger
|
|
|
|
# Name of the directory, created under the build directory, in which the reduced op files are generated.
OP_REDUCTION_DIR = "op_reduction.generated"

# Make <ORT root>/tools/python importable so that the config parsing and type reduction processing modules can be
# imported below. The ORT root is two directory levels above the directory that contains this script.
SCRIPT_DIR = Path(__file__).parent.resolve()
ORT_ROOT = SCRIPT_DIR.parents[1]
sys.path.append(str(ORT_ROOT.joinpath("tools", "python")))
|
|
|
|
from util import parse_config # noqa
|
|
from util.ort_format_model.operator_type_usage_processors import OpTypeImplFilterInterface # noqa
|
|
|
|
# Module-level logger used by the functions and classes in this script.
log = get_logger("reduce_op_kernels")
|
|
|
|
|
|
def _get_extended_minimal_build_required_op_ids() -> typing.Set[typing.Tuple[str, str, int]]:
    """
    Read the ops that graph transformations in an extended minimal build require to be available.

    The ops are listed as `OpIdentifierWithStringViews{<domain constant>, "<optype>", <opset>}` entries located between
    the `@@region_begin(extended_minimal_build_required_kernels)@@` and
    `@@region_end(extended_minimal_build_required_kernels)@@` markers of
    onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h.

    :return: Set of (domain, optype, opset) tuples, where domain is the domain string the ORT constant maps to.
    """
    region_boundary_pattern = re.compile(r"@@region_(begin|end)\(extended_minimal_build_required_kernels\)@@")
    op_id_pattern = re.compile(
        r'OpIdentifierWithStringViews{(?P<domain>\w+),\s+"(?P<optype>\w+)",\s+(?P<opset>\d+)}'
    )

    required_op_ids = set()  # set of (domain, optype, opset)
    in_region = False
    with open(
        ORT_ROOT / "onnxruntime/core/optimizer/transpose_optimizer/layout_transformation_potentially_added_ops.h",
        mode="r",
        encoding="utf-8",
    ) as f:
        for line in f:
            region_boundary_match = region_boundary_pattern.search(line)
            if region_boundary_match:
                # a begin marker opens the region, an end marker closes it
                in_region = region_boundary_match.group(1) == "begin"
                continue

            if not in_region:
                continue

            op_id_match = op_id_pattern.search(line)
            if op_id_match:
                # NOTE(review): allow_unknown_constant=False presumably makes the helper reject unknown domain
                # constants instead of returning None - confirm in op_registration_utils.
                domain = op_registration_utils.map_ort_constant_to_domain(
                    op_id_match.group("domain"), allow_unknown_constant=False
                )
                optype = op_id_match.group("optype")
                opset = int(op_id_match.group("opset"))
                required_op_ids.add((domain, optype, opset))

    return required_op_ids


def _adapt_filters_for_extended_minimal_build(
    base_required_ops: typing.Optional[dict], base_op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]
):
    """
    Adapts the values returned by parse_config() for an extended minimal build or higher.

    In particular:
    - Includes ONNX ops needed by layout transformation
    - Includes MS ops needed by NHWC optimizer

    The ops to include are read from layout_transformation_potentially_added_ops.h.

    :param base_required_ops: Required ops as returned by parse_config() ({domain: {opset: set of optypes}}), or None
        if all ops are required. It is not modified.
    :param base_op_type_impl_filter: Op type implementation filter as returned by parse_config(), or None.
    :return: Tuple of (adapted required ops, adapted op type impl filter). Each element is None if the corresponding
        input is None.
    """
    # graph transformations in an extended minimal build require certain ops to be available
    extended_minimal_build_required_op_ids = _get_extended_minimal_build_required_op_ids()

    adapted_required_ops = None
    if base_required_ops is not None:
        # Copy the nested opset dicts and optype sets as well. A shallow copy would share them with
        # base_required_ops, so adding the required optypes below would also modify the caller's data.
        adapted_required_ops = {
            domain: {opset: set(optypes) for opset, optypes in opset_to_optypes.items()}
            for domain, opset_to_optypes in base_required_ops.items()
        }
        for domain, optype, opset in extended_minimal_build_required_op_ids:
            adapted_required_ops.setdefault(domain, {}).setdefault(opset, set()).add(optype)

    adapted_op_type_impl_filter = None
    if base_op_type_impl_filter is not None:

        class _AdaptedFilter(OpTypeImplFilterInterface):
            """Wraps a filter so that every typed registration of the given (domain, optype) pairs is kept."""

            def __init__(
                self,
                filter_to_adapt: OpTypeImplFilterInterface,
                required_domain_and_optypes: typing.Set[typing.Tuple[str, str]],
            ):
                self.filter_to_adapt = filter_to_adapt
                self.required_domain_and_optypes = required_domain_and_optypes

            def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
                # Always require registration for ops in self.required_domain_and_optypes.
                if (domain, optype) in self.required_domain_and_optypes:
                    return True
                return self.filter_to_adapt.is_typed_registration_needed(domain, optype, type_registration_str)

            def get_cpp_entries(self):
                # The required types for ops in self.required_domain_and_optypes must be specified in the C++
                # implementation. Doing that also accounts for globally allowed types.
                # We don't need to do anything special with the allowed type overrides here.
                return self.filter_to_adapt.get_cpp_entries()

        adapted_op_type_impl_filter = _AdaptedFilter(
            base_op_type_impl_filter,
            {(domain, optype) for (domain, optype, _opset) in extended_minimal_build_required_op_ids},
        )

    return (adapted_required_ops, adapted_op_type_impl_filter)
|
|
|
|
|
|
class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcessor):
    """
    Registration processor that excludes registrations and writes the result to an output file.

    Each registration passed to process_registration() is either copied to the output file unchanged or, when it is
    not needed, written with every line prefixed by "// ". All other lines are copied unchanged.
    """

    def __init__(
        self,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
        output_file: io.TextIOWrapper,
    ):
        """
        :param required_ops: Mapping of domain -> opset -> collection of required op types, or None if every op is
            required.
        :param op_type_impl_filter: Filter deciding which typed registrations of required ops are needed, or None to
            keep all of them.
        :param output_file: Text file that the processed lines are written to.
        """
        self._output_file = output_file
        self._op_type_impl_filter = op_type_impl_filter
        self._required_ops = required_ops

    def _is_op_required(
        self, domain: str, operator: str, start_version: int, end_version: typing.Optional[int]
    ) -> bool:
        """
        Check whether `operator` in `domain` is required for some opset in [start_version, end_version].

        An `end_version` of None means the version range has no upper bound.
        """
        if self._required_ops is None:
            # no restrictions were configured, so everything is required
            return True

        if domain not in self._required_ops:
            return False

        return any(
            operator in required_optypes
            for opset, required_optypes in self._required_ops[domain].items()
            if start_version <= opset and (end_version is None or opset <= end_version)
        )

    def process_registration(
        self,
        lines: typing.List[str],
        constant_for_domain: str,
        operator: str,
        start_version: int,
        end_version: typing.Optional[int] = None,
        type: typing.Optional[str] = None,
    ):
        """
        Write one kernel registration to the output file, commenting it out if it is not needed.

        A registration is kept if its domain constant is unknown (a warning is logged), or if the op is required and,
        for a typed registration, the op type impl filter (when present) reports that the type is needed.

        :param lines: Source lines making up the registration.
        :param constant_for_domain: ORT constant naming the registration's domain.
        :param operator: Op type.
        :param start_version: First opset version covered by the registration.
        :param end_version: Last opset version covered, or None if open-ended.
        :param type: Type string of a typed registration, or None.
        """
        type_suffix = f"<{type}>" if type else ""
        registration_identifier = f"{constant_for_domain}:{operator}({start_version}){type_suffix}"

        # map the ORT constant name to the domain string used in the config
        domain = op_registration_utils.map_ort_constant_to_domain(constant_for_domain)

        reason = None  # remains None when the registration is kept
        if domain is None:
            log.warning(f"Keeping {registration_identifier} registration from unknown domain: {constant_for_domain}")
        elif not self._is_op_required(domain, operator, start_version, end_version):
            reason = "Entire op is not required."
        elif (
            type is not None
            and self._op_type_impl_filter is not None
            and not self._op_type_impl_filter.is_typed_registration_needed(domain, operator, type)
        ):
            reason = "Specific typed registration is not required."

        if reason is None:
            for line in lines:
                self._output_file.write(line)
            return

        log.info(f"Disabling {registration_identifier} registration: {reason}")
        for line in lines:
            self._output_file.write("// " + line)

        # When this is the last entry in the table, its final line also carries the table's terminating "};",
        # which must not be commented out, so emit it again uncommented.
        if lines[-1].rstrip().endswith("};"):
            self._output_file.write("};\n")

    def process_other_line(self, line):
        """Copy a line that is not part of a registration to the output file unchanged."""
        self._output_file.write(line)

    def ok(self):
        """Report success. This processor never records an error, so it always returns True."""
        return True
|
|
|
|
|
|
def _get_op_reduction_root(build_dir: Path):
    """
    Return the directory under which all op reduction files are generated: `build_dir`/OP_REDUCTION_DIR.
    """
    return Path(build_dir) / OP_REDUCTION_DIR
|
|
|
|
|
|
def _get_op_reduction_file_path(ort_root: Path, build_dir: Path, original_path: Path):
    """
    Map `original_path`, a path inside `ort_root`, to its op reduction file path.

    The returned path keeps the location of `original_path` relative to `ort_root`, re-rooted at the op reduction
    root under `build_dir`. Path.relative_to raises ValueError if `original_path` is not inside `ort_root`.
    """
    reduction_root = _get_op_reduction_root(build_dir)
    return reduction_root.joinpath(original_path.relative_to(ort_root))
|
|
|
|
|
|
def _generate_provider_registrations(
    ort_root: Path,
    build_dir: Path,
    use_cuda: bool,
    required_ops: typing.Optional[dict],
    op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
):
    """
    Generate provider registration files.

    For each kernel registration file reported by op_registration_utils, a copy is written under the op reduction
    root. In the copy, registrations that are not required are commented out.

    :raises ValueError: If a kernel registration file does not exist.
    Exits the process with status -1 if the registration processor reports a failure.
    """
    registration_files = list(map(Path, op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)))

    for registration_file in registration_files:
        if not registration_file.is_file():
            raise ValueError(f"Kernel registration file does not exist: {registration_file}")

        log.info(f"Processing {registration_file}")

        reduced_path = _get_op_reduction_file_path(ort_root, build_dir, registration_file)
        reduced_path.parent.mkdir(parents=True, exist_ok=True)

        # Write the reduced copy. Lines for kernels that are not required are commented out by the processor.
        with open(reduced_path, "w") as reduced_file:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, reduced_file)
            op_registration_utils.process_kernel_registration_file(registration_file, processor)

            if not processor.ok():
                # the processor has already logged the error, so just exit
                sys.exit(-1)
|
|
|
|
|
|
def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: typing.Sequence[str]):
    """
    Generate type control overrides. Insert applicable C++ code to specify operator type requirements.

    onnxruntime/core/providers/op_kernel_type_control_overrides.inc is copied to the matching location under the op
    reduction root. If `cpp_lines` is non-empty, every line between the `@@insertion_point_begin(allowed_types)@@`
    and `@@insertion_point_end(allowed_types)@@` marker lines of the copy is replaced with `cpp_lines`. The marker
    lines themselves are kept.

    :param ort_root: Root of the ONNX Runtime repository
    :param build_dir: Path to the build directory
    :param cpp_lines: The C++ code to insert
    :raises ValueError: If the op kernel type control overrides file does not exist.
    :raises RuntimeError: If `cpp_lines` is non-empty and the insertion point is missing or is not terminated by an
        end marker.
    """
    src = Path(ort_root, "onnxruntime", "core", "providers", "op_kernel_type_control_overrides.inc")

    if not src.is_file():
        raise ValueError(f"Op kernel type control overrides file does not exist: {src}")

    # create a copy of op_kernel_type_control_overrides.inc
    target = _get_op_reduction_file_path(ort_root, build_dir, src)
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, target)

    if not cpp_lines:
        return

    # find the insertion block and replace any existing content in it
    inserted = False
    inside_insertion_block = False
    with open(src, "r", encoding="utf-8") as src_file, open(target, "w", encoding="utf-8") as target_file:
        for line in src_file:
            if "@@insertion_point_begin(allowed_types)@@" in line:
                inside_insertion_block = True
                target_file.write(line)
                target_file.writelines(f"{code_line}\n" for code_line in cpp_lines)
                inserted = True
                continue

            if inside_insertion_block:
                if "@@insertion_point_end(allowed_types)@@" not in line:
                    # we ignore any old lines within the insertion block
                    continue
                inside_insertion_block = False

            target_file.write(line)

    if not inserted:
        raise RuntimeError("Insertion point was not found in {}".format(target))

    if inside_insertion_block:
        # Without an end marker, every line after the begin marker would have been dropped from the copy.
        raise RuntimeError("Insertion point end marker was not found in {}".format(target))
|
|
|
|
|
|
def reduce_ops(
    config_path: str,
    build_dir: str,
    enable_type_reduction: bool,
    use_cuda: bool,
    is_extended_minimal_build_or_higher: bool,
):
    """
    Reduce op kernel implementations.

    Any existing op reduction files under `build_dir` are deleted first. Then the reduced kernel registration files
    and the type control overrides file are generated.

    :param config_path: Path to configuration file that specifies the ops to include
    :param build_dir: Path to the build directory. The op reduction files will be generated under the build directory.
    :param enable_type_reduction: Whether per operator type reduction is enabled
    :param use_cuda: Whether to reduce op kernels for the CUDA provider
    :param is_extended_minimal_build_or_higher: Whether this build has at least the features of an extended minimal
        build enabled.
    """
    build_dir_path = Path(build_dir).resolve()
    build_dir_path.mkdir(parents=True, exist_ok=True)

    required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction)
    if is_extended_minimal_build_or_higher:
        required_ops, op_type_impl_filter = _adapt_filters_for_extended_minimal_build(
            required_ops, op_type_impl_filter
        )

    # start from a clean slate by removing files generated by a previous run
    stale_root = _get_op_reduction_root(build_dir_path)
    if stale_root.is_dir():
        log.info(f"Deleting existing op reduction file root directory: {stale_root}")
        shutil.rmtree(stale_root)

    _generate_provider_registrations(ORT_ROOT, build_dir_path, use_cuda, required_ops, op_type_impl_filter)

    if op_type_impl_filter is None:
        type_control_cpp_code = []
    else:
        type_control_cpp_code = op_type_impl_filter.get_cpp_entries()
    _generate_type_control_overrides(ORT_ROOT, build_dir_path, type_control_cpp_code)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Reduces operator kernel implementations in ONNX Runtime. "
            "Entire op implementations or op implementations for specific types may be pruned."
        )
    )

    parser.add_argument(
        "config_path",
        type=str,
        help=(
            "Path to configuration file. "
            "Create with <ORT root>/tools/python/create_reduced_build_config.py and edit if needed. "
            "See https://onnxruntime.ai/docs/reference/reduced-operator-config-file.html for more information."
        ),
    )

    parser.add_argument(
        "--cmake_build_dir",
        type=str,
        required=True,
        help="Path to the build directory. The op reduction files will be generated under the build directory.",
    )

    # Boolean switches. They are registered in this order, which is also the order in which --help lists them.
    boolean_flags = (
        (
            "--is_extended_minimal_build_or_higher",
            "Whether this build has at least the features of an extended minimal build enabled.",
        ),
        ("--enable_type_reduction", "Whether per operator type reduction is enabled."),
        ("--use_cuda", "Whether to reduce op kernels for the CUDA provider."),
    )
    for flag_name, flag_help in boolean_flags:
        parser.add_argument(flag_name, action="store_true", help=flag_help)

    parsed = parser.parse_args()

    reduce_ops(
        config_path=parsed.config_path,
        build_dir=parsed.cmake_build_dir,
        enable_type_reduction=parsed.enable_type_reduction,
        use_cuda=parsed.use_cuda,
        is_extended_minimal_build_or_higher=parsed.is_extended_minimal_build_or_higher,
    )
|