onnxruntime/tools/ci_build/reduce_op_kernels.py
Justin Chu c203d89958
Update ruff and clang-format versions (#21479)
ruff -> 0.5.4
clang-format -> 18
2024-07-24 11:50:11 -07:00

356 lines
14 KiB
Python
Executable file

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import io
import re
import shutil
import sys
import typing
from pathlib import Path
import op_registration_utils
from logger import get_logger
# directory containing the reduced op files, relative to the build directory
OP_REDUCTION_DIR = "op_reduction.generated"
# add the path to tools/python so we can import the config parsing and type reduction processing
# directory containing this script, resolved to an absolute path
SCRIPT_DIR = Path(__file__).parent.resolve()
# two directory levels above SCRIPT_DIR; expected to be the root of the ONNX Runtime repository
ORT_ROOT = SCRIPT_DIR.parents[1]
sys.path.append(str(ORT_ROOT / "tools" / "python"))
# these imports must come after the sys.path update above, hence the E402 (import not at top of file) suppression
from util import parse_config # noqa: E402
from util.ort_format_model.operator_type_usage_processors import OpTypeImplFilterInterface # noqa: E402
# module-level logger used by all functions in this script
log = get_logger("reduce_op_kernels")
def _adapt_filters_for_extended_minimal_build(
    base_required_ops: typing.Optional[dict], base_op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]
):
    """
    Adapts the values returned by parse_config() for an extended minimal build or higher.
    In particular:
    - Includes ONNX ops needed by layout transformation
    - Includes MS ops needed by NHWC optimizer

    The additional ops are read from the `extended_minimal_build_required_kernels` region of
    layout_transformation_potentially_added_ops.h.

    :param base_required_ops: Required ops as a nested mapping of domain -> opset -> set of op types, or None.
                              This argument is not modified.
    :param base_op_type_impl_filter: Op type implementation filter to wrap, or None.
    :return: Tuple of (adapted required ops, adapted op type implementation filter). An element is None if the
             corresponding argument was None.
    """
    # graph transformations in an extended minimal build require certain ops to be available
    extended_minimal_build_required_op_ids = set()  # set of (domain, optype, opset)

    with open(
        ORT_ROOT / "onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h",
    ) as f:
        region_boundary_pattern = re.compile(r"@@region_(begin|end)\(extended_minimal_build_required_kernels\)@@")
        op_id_pattern = re.compile(
            r'OpIdentifierWithStringViews{(?P<domain>\w+),\s+"(?P<optype>\w+)",\s+(?P<opset>\d+)}'
        )
        in_region = False
        for line in f:
            # a boundary line only toggles the region state, it never contains an op entry itself
            region_boundary_match = region_boundary_pattern.search(line)
            if region_boundary_match:
                in_region = region_boundary_match.group(1) == "begin"
                continue

            if not in_region:
                continue

            op_id_match = op_id_pattern.search(line)
            if op_id_match:
                # the header uses the ORT constant name for the domain, the config uses the domain string
                domain = op_registration_utils.map_ort_constant_to_domain(
                    op_id_match.group("domain"), allow_unknown_constant=False
                )
                optype = op_id_match.group("optype")
                opset = int(op_id_match.group("opset"))
                extended_minimal_build_required_op_ids.add((domain, optype, opset))

    adapted_required_ops = None
    if base_required_ops is not None:
        # Copy every nesting level (domain -> opset -> op types). A shallow dict.copy() would share the inner
        # dicts and sets with the caller, so the additions below would silently modify `base_required_ops`.
        adapted_required_ops = {
            base_domain: {base_opset: set(base_optypes) for base_opset, base_optypes in base_opsets.items()}
            for base_domain, base_opsets in base_required_ops.items()
        }
        for domain, optype, opset in extended_minimal_build_required_op_ids:
            adapted_required_ops.setdefault(domain, {}).setdefault(opset, set()).add(optype)

    adapted_op_type_impl_filter = None
    if base_op_type_impl_filter is not None:

        class _AdaptedFilter(OpTypeImplFilterInterface):
            """Filter that always keeps the given (domain, optype) pairs and otherwise defers to another filter."""

            def __init__(
                self,
                filter_to_adapt: OpTypeImplFilterInterface,
                required_domain_and_optypes: typing.Set[typing.Tuple[str, str]],
            ):
                self.filter_to_adapt = filter_to_adapt
                self.required_domain_and_optypes = required_domain_and_optypes

            def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
                # Always require registration for ops in self.required_domain_and_optypes.
                if (domain, optype) in self.required_domain_and_optypes:
                    return True
                return self.filter_to_adapt.is_typed_registration_needed(domain, optype, type_registration_str)

            def get_cpp_entries(self):
                # The required types for ops in self.required_domain_and_optypes must be specified in the C++
                # implementation.
                # Doing that also accounts for globally allowed types.
                # We don't need to do anything special with the allowed type overrides here.
                return self.filter_to_adapt.get_cpp_entries()

        adapted_op_type_impl_filter = _AdaptedFilter(
            base_op_type_impl_filter,
            # the filter is not opset specific, so the opset is dropped
            {(domain, optype) for (domain, optype, _opset) in extended_minimal_build_required_op_ids},
        )

    return (adapted_required_ops, adapted_op_type_impl_filter)
class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcessor):
    """
    Registration processor that writes every processed line to an output file, commenting out the kernel
    registrations that are excluded.
    """

    def __init__(
        self,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
        output_file: io.TextIOWrapper,
    ):
        """
        :param required_ops: Nested mapping of domain -> opset -> op types to keep, or None to keep all ops.
        :param op_type_impl_filter: Filter deciding whether a typed registration is needed, or None for no filtering.
        :param output_file: Open text file that the processed lines are written to.
        """
        self._output_file = output_file
        self._op_type_impl_filter = op_type_impl_filter
        self._required_ops = required_ops

    def _is_op_required(
        self, domain: str, operator: str, start_version: int, end_version: typing.Optional[int]
    ) -> bool:
        """
        Return whether `operator` is required for any opset of `domain` within [start_version, end_version].
        An `end_version` of None means the range has no upper bound.
        """
        required_ops = self._required_ops
        if required_ops is None:
            # no restrictions were configured, so everything is required
            return True

        if domain not in required_ops:
            return False

        return any(
            operator in required_optypes
            for opset, required_optypes in required_ops[domain].items()
            if start_version <= opset and (end_version is None or opset <= end_version)
        )

    def process_registration(
        self,
        lines: typing.List[str],
        constant_for_domain: str,
        operator: str,
        start_version: int,
        end_version: typing.Optional[int] = None,
        type: typing.Optional[str] = None,
    ):
        """
        Write the lines of one kernel registration to the output file, commented out if the registration is
        excluded.

        A registration is excluded when its op is not required, or when the op type implementation filter reports
        that this typed registration is not needed.
        """
        type_suffix = f"<{type}>" if type else ""
        registration_identifier = f"{constant_for_domain}:{operator}({start_version}){type_suffix}"

        # convert from the ORT constant name to the domain string used in the config
        domain = op_registration_utils.map_ort_constant_to_domain(constant_for_domain, allow_unknown_constant=False)

        exclusion_reason = None
        if domain is None:
            log.warning(f"Keeping {registration_identifier} registration from unknown domain: {constant_for_domain}")
        elif not self._is_op_required(domain, operator, start_version, end_version):
            exclusion_reason = "Entire op is not required."
        elif (
            type is not None
            and self._op_type_impl_filter is not None
            and not self._op_type_impl_filter.is_typed_registration_needed(domain, operator, type)
        ):
            exclusion_reason = "Specific typed registration is not required."

        if exclusion_reason is None:
            # registration is kept: pass the lines through unchanged
            self._output_file.writelines(lines)
            return

        log.info(f"Disabling {registration_identifier} registration: {exclusion_reason}")
        self._output_file.writelines("// " + line for line in lines)

        # edge case of last entry in table where we still need the terminating }; to not be commented out
        if lines[-1].rstrip().endswith("};"):
            self._output_file.write("};\n")

    def process_other_line(self, line):
        """Write a line that is not part of a kernel registration to the output file unchanged."""
        self._output_file.write(line)

    def ok(self):
        """Report that processing succeeded. This processor never reports failure."""
        return True
def _get_op_reduction_root(build_dir: Path):
    """
    Return the root directory for op reduction files: the `OP_REDUCTION_DIR` subdirectory of `build_dir`.
    """
    return Path(build_dir) / OP_REDUCTION_DIR
def _get_op_reduction_file_path(ort_root: Path, build_dir: Path, original_path: Path):
    """
    Return the location of the op reduction file that corresponds to `original_path`.
    The result is the op reduction root joined with the path of `original_path` relative to `ort_root`.
    """
    reduction_root = _get_op_reduction_root(build_dir)
    path_within_repo = original_path.relative_to(ort_root)
    return reduction_root / path_within_repo
def _generate_provider_registrations(
    ort_root: Path,
    build_dir: Path,
    use_cuda: bool,
    required_ops: typing.Optional[dict],
    op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
):
    """
    Generate the reduced provider registration files.

    Each kernel registration file is re-written under the op reduction root with the registrations that are not
    required commented out.

    :raises ValueError: If a kernel registration file does not exist.
    """
    registration_file_names = op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)
    original_paths = [Path(file_name) for file_name in registration_file_names]

    for original_path in original_paths:
        if not original_path.is_file():
            raise ValueError(f"Kernel registration file does not exist: {original_path}")

        log.info(f"Processing {original_path}")

        reduced_path = _get_op_reduction_file_path(ort_root, build_dir, original_path)
        reduced_path.parent.mkdir(parents=True, exist_ok=True)

        # the processor copies the original file to the reduced file, commenting out the lines of any kernels
        # that are not required
        with open(reduced_path, "w") as reduced_file:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, reduced_file)
            op_registration_utils.process_kernel_registration_file(original_path, processor)

            if not processor.ok():
                # error should have already been logged so just exit
                sys.exit(-1)
def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: typing.Sequence[str]):
    """
    Generate type control overrides. Insert applicable C++ code to specify operator type requirements.

    The reduced copy of op_kernel_type_control_overrides.inc is written under the op reduction root. If `cpp_lines`
    is empty the file is copied unchanged. Otherwise the content between the `allowed_types` insertion point
    markers is replaced with `cpp_lines`.

    :param ort_root: Root of the ONNX Runtime repository
    :param build_dir: Path to the build directory
    :param cpp_lines: The C++ code to insert
    :raises ValueError: If the overrides file does not exist in the repository.
    :raises RuntimeError: If the insertion point begin marker is not found, or the insertion block is never closed
                          by its end marker.
    """
    src = Path(ort_root, "onnxruntime", "core", "providers", "op_kernel_type_control_overrides.inc")

    if not src.is_file():
        raise ValueError(f"Op kernel type control overrides file does not exist: {src}")

    target = _get_op_reduction_file_path(ort_root, build_dir, src)
    target.parent.mkdir(parents=True, exist_ok=True)

    if not cpp_lines:
        # nothing to insert, so the reduced file is a plain copy of op_kernel_type_control_overrides.inc
        shutil.copyfile(src, target)
        return

    begin_marker = "@@insertion_point_begin(allowed_types)@@"
    end_marker = "@@insertion_point_end(allowed_types)@@"

    # find the insertion block and replace any existing content in it
    inserted = False
    inside_insertion_block = False
    with open(src) as src_file, open(target, "w") as target_file:
        for line in src_file:
            if begin_marker in line:
                inside_insertion_block = True
                target_file.write(line)
                target_file.writelines(f"{code_line}\n" for code_line in cpp_lines)
                inserted = True
                continue

            if inside_insertion_block:
                if end_marker not in line:
                    # we ignore any old lines within the insertion block
                    continue
                inside_insertion_block = False

            target_file.write(line)

    if not inserted:
        raise RuntimeError(f"Insertion point was not found in {target}")

    if inside_insertion_block:
        # without the end marker every line after the begin marker was dropped, so the output is truncated
        raise RuntimeError(f"Insertion point end was not found in {src}")
def reduce_ops(
    config_path: str,
    build_dir: str,
    enable_type_reduction: bool,
    use_cuda: bool,
    is_extended_minimal_build_or_higher: bool,
):
    """
    Reduce op kernel implementations.

    Any previously generated op reduction files under the build directory are deleted, then the reduced kernel
    registration files and the type control overrides file are generated.

    :param config_path: Path to configuration file that specifies the ops to include
    :param build_dir: Path to the build directory. The op reduction files will be generated under the build directory.
    :param enable_type_reduction: Whether per operator type reduction is enabled
    :param use_cuda: Whether to reduce op kernels for the CUDA provider
    :param is_extended_minimal_build_or_higher: Whether this build has at least the features of an extended minimal
                                                build enabled.
    """
    resolved_build_dir = Path(build_dir).resolve()
    resolved_build_dir.mkdir(parents=True, exist_ok=True)

    required_ops, type_filter = parse_config(config_path, enable_type_reduction)
    if is_extended_minimal_build_or_higher:
        required_ops, type_filter = _adapt_filters_for_extended_minimal_build(required_ops, type_filter)

    # delete any existing generated files first
    reduction_root = _get_op_reduction_root(resolved_build_dir)
    if reduction_root.is_dir():
        log.info(f"Deleting existing op reduction file root directory: {reduction_root}")
        shutil.rmtree(reduction_root)

    _generate_provider_registrations(ORT_ROOT, resolved_build_dir, use_cuda, required_ops, type_filter)

    # without a type filter there is no C++ code to insert into the type control overrides
    if type_filter is not None:
        cpp_entries = type_filter.get_cpp_entries()
    else:
        cpp_entries = []
    _generate_type_control_overrides(ORT_ROOT, resolved_build_dir, cpp_entries)
if __name__ == "__main__":
    # command line entry point: parse the arguments and forward them to reduce_ops()
    arg_parser = argparse.ArgumentParser(
        description=(
            "Reduces operator kernel implementations in ONNX Runtime. "
            "Entire op implementations or op implementations for specific types may be pruned."
        )
    )

    config_path_help = (
        "Path to configuration file. "
        "Create with <ORT root>/tools/python/create_reduced_build_config.py and edit if needed. "
        "See https://onnxruntime.ai/docs/reference/operators/reduced-operator-config-file.html for more "
        "information."
    )
    arg_parser.add_argument("config_path", type=str, help=config_path_help)

    arg_parser.add_argument(
        "--cmake_build_dir",
        type=str,
        required=True,
        help="Path to the build directory. The op reduction files will be generated under the build directory.",
    )

    # boolean switches, registered in the order they appear in the help output
    for flag, help_text in (
        (
            "--is_extended_minimal_build_or_higher",
            "Whether this build has at least the features of an extended minimal build enabled.",
        ),
        ("--enable_type_reduction", "Whether per operator type reduction is enabled."),
        ("--use_cuda", "Whether to reduce op kernels for the CUDA provider."),
    ):
        arg_parser.add_argument(flag, action="store_true", help=help_text)

    cli_args = arg_parser.parse_args()

    reduce_ops(
        config_path=cli_args.config_path,
        build_dir=cli_args.cmake_build_dir,
        enable_type_reduction=cli_args.enable_type_reduction,
        use_cuda=cli_args.use_cuda,
        is_extended_minimal_build_or_higher=cli_args.is_extended_minimal_build_or_higher,
    )