onnxruntime/tools/ci_build/reduce_op_kernels.py
Justin Chu fdce4fa6af
Format all python files under onnxruntime with black and isort (#11324)
Description: Format all python files under onnxruntime with black and isort.

After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame.

#11315, #11316
2022-04-26 09:35:16 -07:00

257 lines
9.9 KiB
Python
Executable file

# !/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import shutil
import sys
import typing
from pathlib import Path
import op_registration_utils
from logger import get_logger
# directory containing the reduced op files, relative to the build directory
OP_REDUCTION_DIR = "op_reduction.generated"
# add the path to /tools/python so we can import the config parsing and type reduction processing
SCRIPT_DIR = Path(__file__).parent.resolve()
ORT_ROOT = SCRIPT_DIR.parents[1]
sys.path.append(str(ORT_ROOT / "tools" / "python"))
from util import parse_config # noqa
from util.ort_format_model.operator_type_usage_processors import OpTypeImplFilterInterface # noqa
log = get_logger("reduce_op_kernels")
class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcessor):
"""Registration processor that excludes registrations and writes the result to an output file."""
def __init__(
self,
required_ops: typing.Optional[dict],
op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
output_file: str,
):
self._required_ops = required_ops
self._op_type_impl_filter = op_type_impl_filter
self._output_file = output_file
def _is_op_required(
self, domain: str, operator: str, start_version: int, end_version: typing.Optional[int]
) -> typing.Tuple[bool, str]:
"""See if an op is required."""
if self._required_ops is None:
return True
if domain not in self._required_ops:
return False
for opset in self._required_ops[domain]:
if opset >= start_version and (end_version is None or opset <= end_version):
if operator in self._required_ops[domain][opset]:
return True
return False
def process_registration(
self,
lines: typing.List[str],
constant_for_domain: str,
operator: str,
start_version: int,
end_version: typing.Optional[int] = None,
type: typing.Optional[str] = None,
):
registration_identifier = "{}:{}({}){}".format(
constant_for_domain, operator, start_version, "<{}>".format(type) if type else ""
)
# convert from the ORT constant name to the domain string used in the config
domain = op_registration_utils.map_ort_constant_to_domain(constant_for_domain)
exclude = False
reason = ""
if domain is not None:
if not self._is_op_required(domain, operator, start_version, end_version):
exclude = True
reason = "Entire op is not required."
if not exclude and type is not None and self._op_type_impl_filter is not None:
if not self._op_type_impl_filter.is_typed_registration_needed(domain, operator, type):
exclude = True
reason = "Specific typed registration is not required."
else:
log.warning(
"Keeping {} registration from unknown domain: {}".format(registration_identifier, constant_for_domain)
)
if exclude:
log.info("Disabling {} registration: {}".format(registration_identifier, reason))
for line in lines:
self._output_file.write("// " + line)
# edge case of last entry in table where we still need the terminating }; to not be commented out
if lines[-1].rstrip().endswith("};"):
self._output_file.write("};\n")
else:
for line in lines:
self._output_file.write(line)
def process_other_line(self, line):
self._output_file.write(line)
def ok(self):
return True
def _get_op_reduction_file_path(ort_root: Path, build_dir: Path, original_path: typing.Optional[Path] = None):
"""
Return the op reduction file path corresponding to `original_path` or the op reduction file root if unspecified.
Op reduction files are in a subdirectory of `build_dir` but otherwise share the same components of `original_path`
relative to `ort_root`.
"""
op_reduction_root = Path(build_dir, OP_REDUCTION_DIR)
return (op_reduction_root / original_path.relative_to(ort_root)) if original_path is not None else op_reduction_root
def _generate_provider_registrations(
ort_root: Path,
build_dir: Path,
use_cuda: bool,
required_ops: typing.Optional[dict],
op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
):
"""Generate provider registration files."""
kernel_registration_files = [
Path(f) for f in op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)
]
for kernel_registration_file in kernel_registration_files:
if not kernel_registration_file.is_file():
raise ValueError(f"Kernel registration file does not exist: {kernel_registration_file}")
log.info("Processing {}".format(kernel_registration_file))
reduced_path = _get_op_reduction_file_path(ort_root, build_dir, kernel_registration_file)
reduced_path.parent.mkdir(parents=True, exist_ok=True)
# read from original and create the reduced kernel def file with commented out lines for any kernels that are
# not required
with open(reduced_path, "w") as file_to_write:
processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, file_to_write)
op_registration_utils.process_kernel_registration_file(kernel_registration_file, processor)
if not processor.ok():
# error should have already been logged so just exit
sys.exit(-1)
def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: typing.Sequence[str]):
"""
Generate type control overrides. Insert applicable C++ code to specify operator type requirements.
:param ort_root: Root of the ONNX Runtime repository
:param build_dir: Path to the build directory
:param cpp_lines: The C++ code to insert
"""
src = Path(ort_root, "onnxruntime", "core", "providers", "op_kernel_type_control_overrides.inc")
if not src.is_file():
raise ValueError(f"Op kernel type control overrides file does not exist: {src}")
# create a copy of op_kernel_type_control_overrides.inc
target = _get_op_reduction_file_path(ort_root, build_dir, src)
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src, target)
if cpp_lines:
# find the insertion block and replace any existing content in it
inserted = False
with open(src, "r") as input, open(target, "w") as output:
inside_insertion_block = False
for line in input.readlines():
if "@@insertion_point_begin(allowed_types)@@" in line:
inside_insertion_block = True
output.write(line)
[output.write("{}\n".format(code_line)) for code_line in cpp_lines]
inserted = True
continue
elif inside_insertion_block:
if "@@insertion_point_end(allowed_types)@@" in line:
inside_insertion_block = False
else:
# we ignore any old lines within the insertion block
continue
output.write(line)
if not inserted:
raise RuntimeError("Insertion point was not found in {}".format(target))
def reduce_ops(config_path: str, build_dir: str, enable_type_reduction: bool = False, use_cuda: bool = True):
"""
Reduce op kernel implementations.
:param config_path: Path to configuration file that specifies the ops to include
:param build_dir: Path to the build directory. The op reduction files will be generated under the build directory.
:param enable_type_reduction: Whether per operator type reduction is enabled
:param use_cuda: Whether to reduce op kernels for the CUDA provider
"""
build_dir = Path(build_dir).resolve()
build_dir.mkdir(parents=True, exist_ok=True)
required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction)
# delete any existing generated files first
op_reduction_root = _get_op_reduction_file_path(ORT_ROOT, build_dir)
if op_reduction_root.is_dir():
log.info(f"Deleting existing op reduction file root directory: {op_reduction_root}")
shutil.rmtree(op_reduction_root)
_generate_provider_registrations(ORT_ROOT, build_dir, use_cuda, required_ops, op_type_impl_filter)
type_control_cpp_code = op_type_impl_filter.get_cpp_entries() if op_type_impl_filter is not None else []
_generate_type_control_overrides(ORT_ROOT, build_dir, type_control_cpp_code)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Reduces operator kernel implementations in ONNX Runtime. "
"Entire op implementations or op implementations for specific types may be pruned."
)
parser.add_argument(
"config_path",
type=str,
help="Path to configuration file. "
"Create with <ORT root>/tools/python/create_reduced_build_config.py and edit if needed. "
"See /docs/ONNX_Runtime_Format_Model_Usage.md for more information.",
)
parser.add_argument(
"--cmake_build_dir",
type=str,
required=True,
help="Path to the build directory. " "The op reduction files will be generated under the build directory.",
)
parser.add_argument(
"--enable_type_reduction", action="store_true", help="Whether per operator type reduction is enabled."
)
parser.add_argument("--use_cuda", action="store_true", help="Whether to reduce op kernels for the CUDA provider.")
args = parser.parse_args()
reduce_ops(
config_path=args.config_path,
build_dir=args.cmake_build_dir,
enable_type_reduction=args.enable_type_reduction,
use_cuda=args.use_cuda,
)