onnxruntime/tools/ci_build/reduce_op_kernels.py
Edward Chen 3bc91c2151
Move reduced ops files into build directory (#10030)
In a reduced ops build, some source files get updated. This change moves the updated files into the build directory. This way, it is easier to simultaneously manage different build directories (with possibly different reduced ops configurations) based on a single source directory.
2021-12-28 19:04:20 -08:00

232 lines
10 KiB
Python
Executable file

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import op_registration_utils
import shutil
import sys
import typing
from logger import get_logger
from pathlib import Path
# directory containing the reduced op files, relative to the build directory
OP_REDUCTION_DIR = "op_reduction.generated"

# add the path to /tools/python so we can import the config parsing and type reduction processing
# SCRIPT_DIR is the directory containing this script; ORT_ROOT is two directory levels above it (the repository root)
SCRIPT_DIR = Path(__file__).parent.resolve()
ORT_ROOT = SCRIPT_DIR.parents[1]
sys.path.append(str(ORT_ROOT / 'tools' / 'python'))

# NOTE: these imports must stay below the sys.path update above, hence the noqa markers
from util import parse_config  # noqa
from util.ort_format_model.operator_type_usage_processors import OpTypeImplFilterInterface  # noqa

log = get_logger("reduce_op_kernels")
class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcessor):
    '''
    Registration processor that excludes registrations and writes the result to an output file.

    Registrations that are not required are not dropped. They are written out commented out, with each line
    prefixed by '// '. Every other line is copied through unchanged.
    '''

    def __init__(self, required_ops: typing.Optional[dict],
                 op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
                 output_file: typing.TextIO):
        '''
        :param required_ops: Required ops, looked up as required_ops[domain][opset] -> collection of op names.
                             None means every op is required.
        :param op_type_impl_filter: Optional filter deciding whether a specific typed registration is needed.
                                    None disables per-type filtering.
        :param output_file: Writable text file object that receives the processed lines.
        '''
        self._required_ops = required_ops
        self._op_type_impl_filter = op_type_impl_filter
        self._output_file = output_file

    def _is_op_required(self, domain: str, operator: str,
                        start_version: int, end_version: typing.Optional[int]) -> bool:
        '''
        See if an op is required.

        An op is required if any required opset for `domain` falls within [start_version, end_version]
        (end_version None means unbounded) and that opset lists `operator`.
        '''
        if self._required_ops is None:
            return True

        if domain not in self._required_ops:
            return False

        for opset in self._required_ops[domain]:
            if opset >= start_version and (end_version is None or opset <= end_version):
                if operator in self._required_ops[domain][opset]:
                    return True

        return False

    def process_registration(self, lines: typing.List[str], constant_for_domain: str, operator: str,
                             start_version: int, end_version: typing.Optional[int] = None,
                             type: typing.Optional[str] = None):
        '''
        Write one kernel registration to the output file, commenting it out if it is not required.

        :param lines: Source lines of the registration. They are written verbatim, so they should carry their own
                      line endings.
        :param constant_for_domain: ORT domain constant name used in the registration.
        :param operator: Operator name.
        :param start_version: First opset version covered by the registration.
        :param end_version: Last opset version covered, or None if open-ended.
        :param type: Type of a typed registration, or None for an untyped registration.
        '''
        registration_identifier = '{}:{}({}){}'.format(constant_for_domain, operator, start_version,
                                                       '<{}>'.format(type) if type else '')

        # convert from the ORT constant name to the domain string used in the config
        domain = op_registration_utils.map_ort_constant_to_domain(constant_for_domain)

        exclude = False
        reason = ""

        if domain is not None:
            if not self._is_op_required(domain, operator, start_version, end_version):
                exclude = True
                reason = "Entire op is not required."

            if not exclude and type is not None and self._op_type_impl_filter is not None:
                if not self._op_type_impl_filter.is_typed_registration_needed(domain, operator, type):
                    exclude = True
                    reason = "Specific typed registration is not required."
        else:
            # registrations from domains we cannot map are kept to be safe
            log.warning('Keeping {} registration from unknown domain: {}'
                        .format(registration_identifier, constant_for_domain))

        if exclude:
            log.info('Disabling {} registration: {}'.format(registration_identifier, reason))
            self._output_file.writelines('// ' + line for line in lines)

            # edge case of last entry in table where we still need the terminating }; to not be commented out.
            # guard against an empty `lines` so this cannot raise IndexError.
            if lines and lines[-1].rstrip().endswith('};'):
                self._output_file.write('};\n')
        else:
            self._output_file.writelines(lines)

    def process_other_line(self, line: str):
        '''Copy a non-registration line to the output file unchanged.'''
        self._output_file.write(line)

    def ok(self):
        '''Always returns True; this processor does not detect any error conditions.'''
        return True
def _get_op_reduction_file_path(ort_root: Path, build_dir: Path, original_path: typing.Optional[Path] = None):
    '''
    Map a repository file to its op reduction counterpart under the build directory.

    Op reduction files live below `build_dir / OP_REDUCTION_DIR`. They mirror the repository layout, keeping the
    components of `original_path` relative to `ort_root`.

    :param ort_root: Root of the ONNX Runtime repository.
    :param build_dir: Path to the build directory.
    :param original_path: A path inside `ort_root`. If omitted, the op reduction root directory is returned.
    :return: The op reduction path for `original_path`, or the op reduction root directory.
    '''
    reduction_root = Path(build_dir, OP_REDUCTION_DIR)
    if original_path is None:
        return reduction_root

    # Path.relative_to raises ValueError if original_path is not under ort_root
    path_within_repo = original_path.relative_to(ort_root)
    return reduction_root / path_within_repo
def _generate_provider_registrations(
        ort_root: Path, build_dir: Path, use_cuda: bool,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
    '''
    Generate provider registration files.

    For each kernel registration file reported by op_registration_utils, a reduced copy is written to its op
    reduction path under `build_dir`. Registrations that are not required are commented out.

    :param ort_root: Root of the ONNX Runtime repository.
    :param build_dir: Path to the build directory.
    :param use_cuda: Whether to include the CUDA provider's kernel registration files.
    :param required_ops: Required ops from the config, or None to keep all ops.
    :param op_type_impl_filter: Optional per-type registration filter.
    :raises ValueError: If a kernel registration file does not exist.
    '''
    registration_paths = list(map(Path, op_registration_utils.get_kernel_registration_files(str(ort_root),
                                                                                           use_cuda)))

    for registration_path in registration_paths:
        if not registration_path.is_file():
            raise ValueError(f'Kernel registration file does not exist: {registration_path}')

        log.info("Processing {}".format(registration_path))

        output_path = _get_op_reduction_file_path(ort_root, build_dir, registration_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # read the original file and write the reduced kernel def file, with commented out lines for any kernels
        # that are not required
        with open(output_path, 'w') as output_file:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, output_file)
            op_registration_utils.process_kernel_registration_file(registration_path, processor)
            succeeded = processor.ok()

        if not succeeded:
            # error should have already been logged so just exit
            sys.exit(-1)
def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: typing.Sequence[str]):
    '''
    Generate type control overrides. Insert applicable C++ code to specify operator type requirements.

    A copy of op_kernel_type_control_overrides.inc is always written to the op reduction directory under
    `build_dir`. If `cpp_lines` is non-empty, the copy is rewritten: any lines between the
    '@@insertion_point_begin(allowed_types)@@' and '@@insertion_point_end(allowed_types)@@' marker lines are replaced
    with `cpp_lines`. The marker lines themselves are kept.

    :param ort_root: Root of the ONNX Runtime repository
    :param build_dir: Path to the build directory
    :param cpp_lines: The C++ code to insert, one line per element (a newline is appended to each)
    :raises ValueError: If the source overrides file does not exist.
    :raises RuntimeError: If `cpp_lines` is non-empty and the insertion block is missing or never terminated.
    '''
    begin_marker = '@@insertion_point_begin(allowed_types)@@'
    end_marker = '@@insertion_point_end(allowed_types)@@'

    src = Path(ort_root, 'onnxruntime', 'core', 'providers', 'op_kernel_type_control_overrides.inc')
    if not src.is_file():
        raise ValueError(f"Op kernel type control overrides file does not exist: {src}")

    # create a copy of op_kernel_type_control_overrides.inc
    target = _get_op_reduction_file_path(ort_root, build_dir, src)
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, target)

    if not cpp_lines:
        return

    # find the insertion block and replace any existing content in it
    inserted = False
    inside_insertion_block = False
    with open(src, 'r') as input_file, open(target, 'w') as output_file:
        for line in input_file:
            if begin_marker in line:
                inside_insertion_block = True
                output_file.write(line)
                output_file.writelines(f'{code_line}\n' for code_line in cpp_lines)
                inserted = True
                continue

            if inside_insertion_block:
                if end_marker not in line:
                    # we ignore any old lines within the insertion block
                    continue
                inside_insertion_block = False

            output_file.write(line)

    if not inserted:
        raise RuntimeError('Insertion point was not found in {}'.format(target))

    if inside_insertion_block:
        # without an end marker every line after the begin marker was dropped, so the output is truncated
        raise RuntimeError('End of insertion block was not found in {}'.format(target))
def reduce_ops(config_path: str, build_dir: str, enable_type_reduction: bool = False, use_cuda: bool = True):
    '''
    Reduce op kernel implementations.

    Any op reduction files left by a previous run are deleted first. Then the reduced provider registration files
    and the type control overrides file are regenerated under `build_dir`.

    NOTE: `use_cuda` defaults to True here, whereas the --use_cuda command line flag defaults to off.

    :param config_path: Path to configuration file that specifies the ops to include
    :param build_dir: Path to the build directory. The op reduction files will be generated under the build directory.
    :param enable_type_reduction: Whether per operator type reduction is enabled
    :param use_cuda: Whether to reduce op kernels for the CUDA provider
    '''
    build_path = Path(build_dir).resolve()
    build_path.mkdir(parents=True, exist_ok=True)

    required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction)

    # start from a clean slate so no stale generated files survive
    reduction_root = _get_op_reduction_file_path(ORT_ROOT, build_path)
    if reduction_root.is_dir():
        log.info(f"Deleting existing op reduction file root directory: {reduction_root}")
        shutil.rmtree(reduction_root)

    _generate_provider_registrations(ORT_ROOT, build_path, use_cuda, required_ops, op_type_impl_filter)

    if op_type_impl_filter is None:
        type_control_cpp_code = []
    else:
        type_control_cpp_code = op_type_impl_filter.get_cpp_entries()

    _generate_type_control_overrides(ORT_ROOT, build_path, type_control_cpp_code)
if __name__ == "__main__":
    # command line entry point: parse arguments and forward them to reduce_ops
    arg_parser = argparse.ArgumentParser(
        description="Reduces operator kernel implementations in ONNX Runtime. "
                    "Entire op implementations or op implementations for specific types may be pruned.")

    arg_parser.add_argument(
        "config_path",
        type=str,
        help="Path to configuration file. "
             "Create with <ORT root>/tools/python/create_reduced_build_config.py and edit if needed. "
             "See /docs/ONNX_Runtime_Format_Model_Usage.md for more information.")

    arg_parser.add_argument(
        "--cmake_build_dir",
        type=str,
        required=True,
        help="Path to the build directory. "
             "The op reduction files will be generated under the build directory.")

    arg_parser.add_argument(
        "--enable_type_reduction",
        action="store_true",
        help="Whether per operator type reduction is enabled.")

    arg_parser.add_argument(
        "--use_cuda",
        action="store_true",
        help="Whether to reduce op kernels for the CUDA provider.")

    parsed_args = arg_parser.parse_args()

    reduce_ops(config_path=parsed_args.config_path,
               build_dir=parsed_args.cmake_build_dir,
               enable_type_reduction=parsed_args.enable_type_reduction,
               use_cuda=parsed_args.use_cuda)