onnxruntime/tools/python/create_reduced_build_config.py
Justin Chu a36caba073
Bump ruff in CI (#15533)
### Description

Bump the ruff version in CI and fix the new lint errors.

- This change enables the flake8-implicit-str-concat rules, which help
detect unintended string concatenations:
https://beta.ruff.rs/docs/rules/#flake8-implicit-str-concat-isc
- Update .gitignore to cover common Python files that we want to
exclude.


### Motivation and Context

Code quality
2023-04-17 10:11:44 -07:00

158 lines
5.9 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import pathlib
import sys
import typing
import onnx
from util.file_utils import files_from_file_or_dir, path_match_suffix_ignore_case
def _get_suffix_match_predicate(suffix: str):
    """Build a file filter that checks whether a path ends with ``suffix``.

    The check itself is delegated to ``path_match_suffix_ignore_case`` (case-insensitive, per its name).

    Args:
        suffix: File suffix to match, e.g. ".onnx".

    Returns:
        A callable that takes a ``pathlib.Path`` and returns the result of the suffix check.
    """
    return lambda file_path: path_match_suffix_ignore_case(file_path, suffix)
def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map):
    """Extract ops from an ONNX graph and all subgraphs.

    For every node, its ``op_type`` is added to ``operators[domain][opset]``, where ``opset`` is the version
    that ``domain_opset_map`` records for the node's domain. A node whose domain is missing from either
    ``operators`` or ``domain_opset_map`` is skipped entirely, including any subgraphs it holds.

    Args:
        graph: ONNX graph to scan (e.g. ``model.graph``, or the ``g`` field of a GRAPH attribute).
        operators: Mapping of domain -> {opset version -> set of op types}; updated in place.
        domain_opset_map: Mapping of domain -> opset version imported by the model.

    Raises:
        RuntimeError: If a node carries an attribute of type GRAPHS, which is not supported.
    """
    for node in graph.node:
        # An empty domain is an alias for 'ai.onnx'.
        node_domain = node.domain or "ai.onnx"
        if node_domain in operators and node_domain in domain_opset_map:
            operators[node_domain][domain_opset_map[node_domain]].add(node.op_type)

            for attribute in node.attribute:
                if attribute.type == onnx.AttributeProto.GRAPH:
                    # Nodes can hold subgraphs in GRAPH attributes; recurse so their ops are recorded too.
                    _extract_ops_from_onnx_graph(attribute.g, operators, domain_opset_map)
                elif attribute.type == onnx.AttributeProto.GRAPHS:
                    # Currently no ONNX operators use GRAPHS.
                    # Fail noisily so support can be added instead of silently missing ops.
                    raise RuntimeError("Unexpected attribute proto of GRAPHS")
def _process_onnx_model(model_path, required_ops):
    """Load one ONNX model and merge the ops it uses into ``required_ops``.

    Every domain/opset pair the model imports gets an entry in ``required_ops`` (an empty set if it
    was not there yet), even when no node in the graph ends up using it.

    Args:
        model_path: Path of the ONNX model file, passed to ``onnx.load``.
        required_ops: Mapping of domain -> {opset version -> set of op types}; updated in place.
    """
    model = onnx.load(model_path)

    # Domain -> opset version imported by this model. An empty domain means 'ai.onnx'.
    domain_opset_map = {}
    for opset_import in model.opset_import:
        domain = opset_import.domain or "ai.onnx"
        domain_opset_map[domain] = opset_import.version
        required_ops.setdefault(domain, {}).setdefault(opset_import.version, set())

    # A model that imports no opset is an unexpected edge case: without imports we cannot tell which
    # opset the graph's nodes belong to, so the graph is not scanned.
    if not domain_opset_map:
        return

    _extract_ops_from_onnx_graph(model.graph, required_ops, domain_opset_map)
def _extract_ops_from_onnx_model(model_files: typing.Iterable[pathlib.Path]):
    """Extract ops from ONNX models.

    Args:
        model_files: Paths of the ONNX model files to scan, processed in iteration order.

    Returns:
        dict: domain -> {opset version -> set of op types}, accumulated over all models.

    Raises:
        ValueError: If a path is not a file (``Path.is_file()`` is False). Models earlier in the
            iteration have already been processed at that point.
    """
    required_ops = {}
    for path in model_files:
        if path.is_file():
            _process_onnx_model(path, required_ops)
        else:
            raise ValueError(f"Path is not a file: '{path}'")
    return required_ops
def create_config_from_onnx_models(model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path):
    """Write a reduced build configuration file listing the ops used by the given ONNX models.

    The file starts with comment lines naming the source models. It then has one
    ``<domain>;<opset>;<op1>,<op2>,...`` line per domain/opset pair that has at least one op.
    Domains, opsets, model paths and op names are sorted so the output is deterministic.

    Args:
        model_files: Paths of the ONNX model files to process. Any iterable is accepted; it is
            materialized once because it is consumed twice (op extraction and the header).
        output_file: Path of the configuration file to write. Missing parent directories are created.

    Raises:
        ValueError: If any entry in ``model_files`` is not a file (raised during op extraction).
    """
    # The model list is traversed twice below. Materialize it so that a one-shot iterator (e.g. a
    # generator) does not leave the header's model list empty after op extraction consumed it.
    model_files = list(model_files)

    required_ops = _extract_ops_from_onnx_model(model_files)

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w") as out:
        out.write("# Generated from ONNX model/s:\n")
        for model_file in sorted(model_files):
            out.write(f"# - {model_file}\n")

        for domain in sorted(required_ops):
            for opset in sorted(required_ops[domain]):
                ops = required_ops[domain][opset]
                # Opsets imported by a model but not used by any node have empty sets; omit them.
                if ops:
                    out.write(f"{domain};{opset};{','.join(sorted(ops))}\n")
def main():
    """Command-line entry point: parse arguments and write a reduced build configuration file.

    Exits with status -1 if type reduction is requested for ONNX format models, or if the given model
    path does not exist.
    """
    argparser = argparse.ArgumentParser(
        # Must be passed as `description`: the first positional parameter of ArgumentParser is `prog`,
        # which would otherwise replace the program name in the usage line with this whole sentence.
        description="Script to create a reduced build config file from either ONNX or ORT format model/s. "
        "See /docs/Reduced_Operator_Kernel_build.md for more information on the configuration file format.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    argparser.add_argument(
        "-f", "--format", choices=["ONNX", "ORT"], default="ONNX", help="Format of model/s to process."
    )

    argparser.add_argument(
        "-t",
        "--enable_type_reduction",
        action="store_true",
        help="Enable tracking of the specific types that individual operators require. "
        "Operator implementations MAY support limiting the type support included in the build "
        "to these types. Only possible with ORT format models.",
    )

    argparser.add_argument(
        "model_path_or_dir",
        type=pathlib.Path,
        help="Path to a single model, or a directory that will be recursively searched for models to process.",
    )

    argparser.add_argument(
        "config_path",
        nargs="?",
        type=pathlib.Path,
        default=None,
        help="Path to write configuration file to. Default is to write to required_operators.config "
        "or required_operators_and_types.config in the same directory as the models.",
    )

    args = argparser.parse_args()

    if args.enable_type_reduction and args.format == "ONNX":
        print("Type reduction requires model format to be ORT.", file=sys.stderr)
        sys.exit(-1)

    model_path_or_dir = args.model_path_or_dir.resolve()
    if not model_path_or_dir.exists():
        # Fail early with a clear message; a mistyped path could otherwise lead to an empty config being
        # written next to it, or to a less helpful error further down.
        print(f"Model path does not exist: '{model_path_or_dir}'", file=sys.stderr)
        sys.exit(-1)

    if args.config_path:
        config_path = args.config_path.resolve()
    else:
        config_path = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent

    # A directory (explicit or defaulted) gets the standard config file name appended.
    if config_path.is_dir():
        filename = "required_operators_and_types.config" if args.enable_type_reduction else "required_operators.config"
        config_path = config_path.joinpath(filename)

    if args.format == "ONNX":
        model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".onnx"))
        create_config_from_onnx_models(model_files, config_path)
    else:
        # Imported lazily so the ORT-format dependencies are only needed when processing ORT models.
        from util.ort_format_model import create_config_from_models as create_config_from_ort_models

        model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".ort"))
        create_config_from_ort_models(model_files, config_path, args.enable_type_reduction)

    # Debug code to validate that the config parsing matches.
    # Uses the resolved `config_path`; `args.config_path` is None when the default location is used.
    # from util import parse_config
    # required_ops, op_type_usage_processor, _ = parse_config(config_path, True)
    # op_type_usage_processor.debug_dump()


if __name__ == "__main__":
    main()