mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
262 lines
10 KiB
Python
262 lines
10 KiB
Python
# !/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
'''
|
|
Exclude operators that are unused/not required from build to reduce binary size.
|
|
'''
|
|
|
|
import argparse
|
|
import onnx
|
|
import op_registration_utils
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import typing
|
|
|
|
from onnx import AttributeProto
|
|
from logger import get_logger
|
|
|
|
|
|
# Module-level logger used by the functions in this script to report progress and errors.
log = get_logger("exclude_unused_ops")
|
|
|
|
|
|
def _extract_ops_from_config(file_path, required_ops):
|
|
'''extract ops from config file of format: domain;opset;op1,op2...'''
|
|
|
|
if not file_path:
|
|
return required_ops
|
|
|
|
if not os.path.isfile(file_path):
|
|
# exit. to continue may result in unexpectedly disabling all kernels.
|
|
log.error('Configuration file {} does not exist'.format(file_path))
|
|
sys.exit(-1)
|
|
|
|
with open(file_path, 'r') as file_to_read:
|
|
for stripped_line in [line.strip() for line in file_to_read.readlines()]:
|
|
|
|
if not stripped_line: # skip empty lines
|
|
continue
|
|
|
|
if stripped_line.startswith("#"): # skip comments
|
|
continue
|
|
|
|
raw_domain, raw_opset, raw_ops = [segment.strip() for segment in stripped_line.split(';')]
|
|
|
|
domain = op_registration_utils.map_domain(raw_domain)
|
|
opset = int(raw_opset)
|
|
operators = set([raw_op.strip() for raw_op in raw_ops.split(',')])
|
|
|
|
if domain not in required_ops:
|
|
required_ops[domain] = {opset: operators}
|
|
elif opset not in required_ops[domain]:
|
|
required_ops[domain][opset] = operators
|
|
else:
|
|
required_ops[domain][opset].update(operators)
|
|
|
|
return required_ops # end of _extract_ops_from_file(...)
|
|
|
|
|
|
def _extract_ops_from_graph(graph, operators, domain_opset_map):
|
|
'''extract ops from graph and all subgraphs'''
|
|
|
|
for operator in graph.node:
|
|
mapped_domain = op_registration_utils.map_domain(operator.domain)
|
|
|
|
if mapped_domain not in operators or mapped_domain not in domain_opset_map:
|
|
continue
|
|
|
|
operators[mapped_domain][domain_opset_map[mapped_domain]].add(operator.op_type)
|
|
|
|
for attr in operator.attribute:
|
|
if attr.type == AttributeProto.GRAPH: # process subgraph
|
|
_extract_ops_from_graph(attr.g, operators, domain_opset_map)
|
|
|
|
|
|
def _extract_ops_from_model(model_path, required_ops):
|
|
'''extract ops from models under model_path and return a diction'''
|
|
|
|
if not model_path:
|
|
return required_ops
|
|
|
|
if not os.path.isdir(model_path):
|
|
# exit. to continue may result in unexpectedly disabling all kernels.
|
|
log.error('Directory containing models does not exist: {}'.format(model_path))
|
|
sys.exit(-1)
|
|
|
|
for root, _, files in os.walk(model_path):
|
|
for file in files:
|
|
if file.endswith('.onnx'):
|
|
model_path = os.path.join(root, file)
|
|
model = onnx.load(model_path)
|
|
domain_opset_map = {}
|
|
|
|
if len(model.opset_import) == 0:
|
|
continue
|
|
|
|
for opset in model.opset_import:
|
|
mapped_domain = op_registration_utils.map_domain(opset.domain)
|
|
domain_opset_map[mapped_domain] = opset.version
|
|
|
|
if mapped_domain not in required_ops:
|
|
required_ops[mapped_domain] = {opset.version: set()}
|
|
|
|
elif opset.version not in required_ops[mapped_domain]:
|
|
required_ops[mapped_domain][opset.version] = set()
|
|
|
|
_extract_ops_from_graph(model.graph, required_ops, domain_opset_map)
|
|
|
|
return required_ops # end of _extract_ops_from_model(...)
|
|
|
|
|
|
class ExcludeOpsRegistrationProcessor(op_registration_utils.RegistrationProcessor):
    '''Writes a kernel registration file to output_file, commenting out the registration
    of every operator that is not in required_ops.

    NOTE(review): the process_* methods are presumably called back by
    op_registration_utils.process_kernel_registration_file - confirm against that module.
    '''

    def __init__(self, required_ops, output_file):
        '''
        :param required_ops: Dictionary of {domain: {opset: set of operator names}} to keep.
        :param output_file: Object with a write method that receives the processed lines.
        '''
        self.required_ops = required_ops
        self.output_file = output_file

    def _should_exclude_op(self, domain, operator, start_version, end_version):
        '''Return True unless operator is required for an opset between start_version and end_version.
        An end_version of None means there is no upper bound.'''
        if domain not in self.required_ops:
            return True

        is_required = any(
            operator in ops
            for opset, ops in self.required_ops[domain].items()
            if opset >= start_version and (end_version is None or opset <= end_version)
        )

        return not is_required

    def process_registration(self, lines: typing.List[str], domain: str, operator: str,
                             start_version: int, end_version: int = None, input_type: str = None):
        '''Write the lines of one registration, prefixed with '// ' if the operator is not required.'''
        if not self._should_exclude_op(domain, operator, start_version, end_version):
            # required, so the registration is written unchanged
            for line in lines:
                self.output_file.write(line)
            return

        log.info('Disabling {}:{}({})'.format(domain, operator, start_version))
        for line in lines:
            self.output_file.write('// ' + line)

        # edge case of last entry in table where we still need the terminating }; to not be commented out
        last_line = lines[-1].rstrip()
        if last_line.endswith('};'):
            self.output_file.write('};\n')

    def process_other_line(self, line):
        '''Write a line that is not part of a registration unchanged.'''
        self.output_file.write(line)

    def ok(self):
        '''Always True.'''
        return True
|
|
|
|
|
|
def _exclude_unused_ops_in_registrations(required_operators, provider_registration_paths):
|
|
'''rewrite provider registration file to exclude unused ops'''
|
|
|
|
for kernel_registration_file in provider_registration_paths:
|
|
if not os.path.isfile(kernel_registration_file):
|
|
log.warning('Kernel registration file {} does not exist'.format(kernel_registration_file))
|
|
return
|
|
|
|
log.info("Processing {}".format(kernel_registration_file))
|
|
|
|
backup_path = kernel_registration_file + '~'
|
|
shutil.move(kernel_registration_file, backup_path)
|
|
|
|
# read from backup and overwrite original with commented out lines for any kernels that are not required
|
|
with open(kernel_registration_file, 'w') as file_to_write:
|
|
processor = ExcludeOpsRegistrationProcessor(required_operators, file_to_write)
|
|
|
|
op_registration_utils.process_kernel_registration_file(backup_path, processor)
|
|
|
|
if not processor.ok():
|
|
# error should have already been logged so just exit
|
|
sys.exit(-1)
|
|
|
|
|
|
def _create_config_file_with_required_ops(required_operators, model_path, config_path, output_file):
|
|
|
|
directory, filename = os.path.split(output_file)
|
|
if not filename:
|
|
log.error("Invalid path to write final config to. {}".format(output_file))
|
|
sys.exit(-1)
|
|
|
|
if not os.path.exists(directory):
|
|
os.makedirs(directory)
|
|
|
|
with open(output_file, 'w') as out:
|
|
model_path_info = '--model_path {} '.format(model_path) if model_path else ''
|
|
config_path_info = '--config_path {}'.format(config_path) if config_path else ''
|
|
out.write("# Generated from {}{}\n".format(model_path_info, config_path_info))
|
|
|
|
for domain in sorted(required_operators.keys()):
|
|
if domain == 'UnknownDomain':
|
|
continue
|
|
|
|
# reverse the mapping of the domain. entry must exist given we created the required_operators dictionary.
|
|
# also need to handle ai.onnx being special-cased as an empty string
|
|
orig_domain = [key for (key, value) in op_registration_utils.domain_map.items() if value == domain][0]
|
|
if not orig_domain:
|
|
orig_domain = 'ai.onnx'
|
|
|
|
for opset in sorted(required_operators[domain].keys()):
|
|
ops = required_operators[domain][opset]
|
|
if ops:
|
|
out.write("{};{};{}\n".format(orig_domain, opset, ','.join(sorted(ops))))
|
|
|
|
log.info("Wrote set of required operators to {}".format(output_file))
|
|
|
|
|
|
def exclude_unused_ops(models_path, config_path, ort_root=None, use_cuda=True, output_config_path=None):
    '''Determine the operators that are required, and either disable the kernel registrations
    of all other operators or write the required operators to a configuration file.

    Note that this is called directly from build.py

    :param models_path: Directory containing ONNX models to read the required operators from. May be empty.
    :param config_path: Configuration file to read the required operators from. May be empty.
    :param ort_root: Passed to op_registration_utils.get_kernel_registration_files to find
                     the registration files.
    :param use_cuda: Passed to op_registration_utils.get_kernel_registration_files.
    :param output_config_path: If provided, the required operators are written to this file
                               and no kernel registrations are changed.

    The process exits if neither models_path nor config_path is provided.
    '''

    if not (models_path or config_path):
        log.error('Please specify model_path and/or config_path.')
        sys.exit(-1)

    if not (ort_root or output_config_path):
        log.info('ort_root was not specified. Inferring ONNX Runtime repository root from location of this script.')

    # operators from the models are collected first, then the ones from the configuration file are merged in
    ops_from_models = _extract_ops_from_model(models_path, {})
    required_ops = _extract_ops_from_config(config_path, ops_from_models)

    if output_config_path:
        _create_config_file_with_required_ops(required_ops, models_path, config_path, output_config_path)
        return

    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, use_cuda)
    _exclude_unused_ops_in_registrations(required_ops, registration_files)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line entry point. Parses the arguments, makes the provided paths absolute,
    # and hands them to exclude_unused_ops.

    parser = argparse.ArgumentParser(
        # the trailing space keeps the two sentences apart when the literals are concatenated
        description="Script to exclude unused operator kernels by disabling their registration in ONNXRuntime. "
                    "See /docs/Reduced_Operator_Kernel_build.md for more information",
        usage="Provide model_path, config_path, or both, to disable kernel registration of unused kernels.")

    parser.add_argument(
        "--model_path", type=str, help="Path to folder containing one or more ONNX models")

    parser.add_argument(
        "--config_path", type=str, help="Path to configuration file with format of 'domain;opset;op1,op2...'")

    parser.add_argument(
        "--ort_root", type=str, help="Path to ONNXRuntime repository root. "
                                     "Inferred from the location of this script if not provided.")

    parser.add_argument(
        "--write_combined_config_to", type=str,
        help="Optional path to create a configuration file with the combined set of required kernels "
             "from processing --model_path and/or --config_path. If provided, a configuration file will be created "
             "and NO updates will be made to the kernel registrations."
    )

    args = parser.parse_args()

    models_path = os.path.abspath(args.model_path) if args.model_path else ''
    config_path = os.path.abspath(args.config_path) if args.config_path else ''
    ort_root = os.path.abspath(args.ort_root) if args.ort_root else ''

    # made absolute like the other paths, so a bare filename still has a directory component
    output_config_path = os.path.abspath(args.write_combined_config_to) if args.write_combined_config_to else None

    if not models_path and not config_path:
        log.error('Please specify at least either model path or config path.')
        parser.print_help()
        sys.exit(-1)

    exclude_unused_ops(models_path, config_path, ort_root, use_cuda=True,
                       output_config_path=output_config_path)
|