onnxruntime/tools/python/create_reduced_build_config.py
Edward Chen ee35be0129
Support specifying globally allowed types from build script (#6677)
Add initial support for constraining operator kernel implementations (which support this type-granularity) to a set of allowed types from scripts.
2021-02-22 14:05:00 -08:00

144 lines
6 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import os
import onnx
import pathlib
import sys
def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map):
    """Record the op type of every node in `graph`, descending into subgraphs.

    `operators` is updated in place: each node's op type is added to
    operators[domain][opset], where the opset is looked up in `domain_opset_map`.
    Nodes whose domain is missing from either mapping are skipped entirely.

    Raises RuntimeError if a node carries an attribute of type GRAPHS.
    """
    for node in graph.node:
        # nodes in the default domain have an empty domain string, which stands for 'ai.onnx'
        node_domain = node.domain or 'ai.onnx'

        is_tracked = node_domain in operators and node_domain in domain_opset_map
        if is_tracked:
            opset_version = domain_opset_map[node_domain]
            operators[node_domain][opset_version].add(node.op_type)

            for attribute in node.attribute:
                attribute_kind = attribute.type

                if attribute_kind == onnx.AttributeProto.GRAPHS:
                    # No ONNX operator currently uses GRAPHS, so there is no handling for it.
                    # Raise rather than silently ignore so that support gets added if one appears.
                    raise RuntimeError('Unexpected attribute proto of GRAPHS')

                if attribute_kind == onnx.AttributeProto.GRAPH:
                    # the attribute holds a subgraph, so collect its ops as well
                    _extract_ops_from_onnx_graph(attribute.g, operators, domain_opset_map)
def _process_onnx_model(model_path, required_ops):
    """Load the ONNX model at `model_path` and add the ops it uses to `required_ops`.

    `required_ops` maps domain -> opset version -> set of op types and is updated in place.
    An entry is created for every domain/opset pair the model imports, even if no op ends up in it.
    """
    model = onnx.load(model_path)

    # work out which opset version this model uses for each domain it imports
    domain_opset_map = {}
    for opset_import in model.opset_import:
        # an empty domain is shorthand for 'ai.onnx'
        domain = opset_import.domain or 'ai.onnx'
        version = opset_import.version

        domain_opset_map[domain] = version
        # make sure there is a set to collect the ops for this domain/opset pair
        required_ops.setdefault(domain, {}).setdefault(version, set())

    # A model that imports no opsets is an unexpected edge case. It has to be ignored because
    # there is no way to tell which opset the nodes in its graph belong to.
    if not domain_opset_map:
        return

    _extract_ops_from_onnx_graph(model.graph, required_ops, domain_opset_map)
def _extract_ops_from_onnx_model(model_path_or_dir):
    """Gather the required ops from one ONNX model file, or from every '.onnx' file under a directory.

    Directories are searched recursively and the '.onnx' extension is matched case-insensitively.
    Returns a dict of domain -> opset version -> set of op types.
    Raises ValueError if `model_path_or_dir` does not exist.
    """
    if not os.path.exists(model_path_or_dir):
        raise ValueError('Path to model/s does not exist: {}'.format(model_path_or_dir))

    required_ops = {}

    # a single file is processed as-is, whatever its extension
    if os.path.isfile(model_path_or_dir):
        _process_onnx_model(model_path_or_dir, required_ops)
        return required_ops

    for dirpath, _dirnames, filenames in os.walk(model_path_or_dir):
        for name in filenames:
            if not name.lower().endswith('.onnx'):
                continue
            _process_onnx_model(os.path.join(dirpath, name), required_ops)

    return required_ops
def create_config_from_onnx_models(model_path_or_dir: str, output_file: str):
    """Write a reduced build configuration file listing the operators required by ONNX model/s.

    The file starts with a comment line naming `model_path_or_dir`, followed by one line per
    domain/opset pair that has at least one operator, in the form
    `domain;opset;op1,op2,...`. Domains, opsets and operator names are written in sorted order.

    :param model_path_or_dir: Path to a single ONNX model, or a directory to search recursively for models.
    :param output_file: Path of the configuration file to write. Missing parent directories are created.
    :raises ValueError: If `model_path_or_dir` does not exist.
    :raises RuntimeError: If `output_file` has no filename component (e.g. it ends with a path separator).
    """
    required_ops = _extract_ops_from_onnx_model(model_path_or_dir)

    directory, filename = os.path.split(output_file)
    if not filename:
        raise RuntimeError("Invalid output path for configuation: {}".format(output_file))

    # A bare filename has an empty directory component, meaning the current directory. There is nothing
    # to create in that case and os.makedirs('') would raise. exist_ok avoids a check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(output_file, 'w') as out:
        out.write("# Generated from ONNX models path of {}\n".format(model_path_or_dir))

        for domain in sorted(required_ops.keys()):
            for opset in sorted(required_ops[domain].keys()):
                ops = required_ops[domain][opset]
                # an opset that was imported but contributed no operators is left out of the config
                if ops:
                    out.write("{};{};{}\n".format(domain, opset, ','.join(sorted(ops))))
def main():
    """Command line entry point: create a reduced build config file from ONNX or ORT format model/s.

    Parses the command line, works out where the configuration file should be written and then
    dispatches to the ONNX or ORT format specific config creation function.

    If no config path is given, the file is written next to the models. If the config path is an existing
    directory, a default filename is used inside it: 'required_operators_and_types.config' when type
    reduction is enabled, otherwise 'required_operators.config'.

    Exits with status -1 if type reduction is requested for ONNX format models.
    """
    # The text must be passed as `description`. The first positional parameter of ArgumentParser is `prog`,
    # which would make the whole description show up as the program name in the usage line.
    argparser = argparse.ArgumentParser(
        description='Script to create a reduced build config file from either ONNX or ORT format model/s. '
                    'See /docs/Reduced_Operator_Kernel_build.md for more information on the configuration '
                    'file format.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    argparser.add_argument('-f', '--format', choices=['ONNX', 'ORT'], default='ONNX',
                           help='Format of model/s to process.')
    argparser.add_argument('-t', '--enable_type_reduction', action='store_true',
                           help='Enable tracking of the specific types that individual operators require. '
                                'Operator implementations MAY support limiting the type support included in the '
                                'build to these types. Only possible with ORT format models.')
    argparser.add_argument('model_path_or_dir', type=pathlib.Path,
                           help='Path to a single model, or a directory that will be recursively searched '
                                'for models to process.')
    argparser.add_argument('config_path', nargs='?', type=pathlib.Path, default=None,
                           help='Path to write configuration file to. Default is to write to '
                                'required_operators.config or required_operators_and_types.config in the same '
                                'directory as the models.')

    args = argparser.parse_args()

    if args.enable_type_reduction and args.format == 'ONNX':
        print('Type reduction requires model format to be ORT.', file=sys.stderr)
        sys.exit(-1)

    model_path_or_dir = args.model_path_or_dir.resolve()

    if args.config_path:
        config_path = args.config_path.resolve()
    else:
        # default to the directory containing the model/s
        config_path = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent

    # a directory (given explicitly or defaulted above) needs a filename appended to it
    if config_path.is_dir():
        filename = 'required_operators_and_types.config' if args.enable_type_reduction else 'required_operators.config'
        config_path = config_path.joinpath(filename)

    if args.format == 'ONNX':
        create_config_from_onnx_models(model_path_or_dir, config_path)
    else:
        # imported here so the ORT format handling is only loaded when it is actually used
        from util.ort_format_model import create_config_from_models as create_config_from_ort_models
        create_config_from_ort_models(model_path_or_dir, config_path, args.enable_type_reduction)

        # Debug code to validate that the config parsing matches
        # from util import parse_config
        # required_ops, op_type_usage_processor, _ = parse_config(args.config_path, True)
        # op_type_usage_processor.debug_dump()
# Run the command line entry point only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()