#!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import argparse import pathlib import sys import typing import onnx from util.file_utils import files_from_file_or_dir, path_match_suffix_ignore_case def _get_suffix_match_predicate(suffix: str): def predicate(file_path: pathlib.Path): return path_match_suffix_ignore_case(file_path, suffix) return predicate def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map): """Extract ops from an ONNX graph and all subgraphs""" for operator in graph.node: # empty domain is used as an alias for 'ai.onnx' domain = operator.domain if operator.domain else "ai.onnx" if domain not in operators or domain not in domain_opset_map: continue operators[domain][domain_opset_map[domain]].add(operator.op_type) for attr in operator.attribute: if attr.type == onnx.AttributeProto.GRAPH: # process subgraph _extract_ops_from_onnx_graph(attr.g, operators, domain_opset_map) elif attr.type == onnx.AttributeProto.GRAPHS: # Currently no ONNX operators use GRAPHS. # Fail noisily if we encounter this so we can implement support raise RuntimeError("Unexpected attribute proto of GRAPHS") def _process_onnx_model(model_path, required_ops): model = onnx.load(model_path) # create map of domain to opset for the model domain_opset_map = {} for opset in model.opset_import: # empty domain == ai.onnx domain = opset.domain if opset.domain else "ai.onnx" domain_opset_map[domain] = opset.version if domain not in required_ops: required_ops[domain] = {opset.version: set()} elif opset.version not in required_ops[domain]: required_ops[domain][opset.version] = set() # check the model imports at least one opset. if it does not it's an unexpected edge case that we have to ignore # as we don't know what opset nodes in the graph belong to. if domain_opset_map: _extract_ops_from_onnx_graph(model.graph, required_ops, domain_opset_map) def _extract_ops_from_onnx_model(model_files: typing.Iterable[pathlib.Path]): """Extract ops from ONNX models""" required_ops = {} for model_file in model_files: if not model_file.is_file(): raise ValueError(f"Path is not a file: '{model_file}'") _process_onnx_model(model_file, required_ops) return required_ops def create_config_from_onnx_models(model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path): required_ops = _extract_ops_from_onnx_model(model_files) output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w") as out: out.write("# Generated from ONNX model/s:\n") for model_file in sorted(model_files): out.write(f"# - {model_file}\n") for domain in sorted(required_ops.keys()): for opset in sorted(required_ops[domain].keys()): ops = required_ops[domain][opset] if ops: out.write("{};{};{}\n".format(domain, opset, ",".join(sorted(ops)))) def main(): argparser = argparse.ArgumentParser( "Script to create a reduced build config file from either ONNX or ORT format model/s. " "See /docs/Reduced_Operator_Kernel_build.md for more information on the configuration file format.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) argparser.add_argument( "-f", "--format", choices=["ONNX", "ORT"], default="ONNX", help="Format of model/s to process." ) argparser.add_argument( "-t", "--enable_type_reduction", action="store_true", help="Enable tracking of the specific types that individual operators require. " "Operator implementations MAY support limiting the type support included in the build " "to these types. Only possible with ORT format models.", ) argparser.add_argument( "model_path_or_dir", type=pathlib.Path, help="Path to a single model, or a directory that will be recursively searched " "for models to process.", ) argparser.add_argument( "config_path", nargs="?", type=pathlib.Path, default=None, help="Path to write configuration file to. Default is to write to required_operators.config " "or required_operators_and_types.config in the same directory as the models.", ) args = argparser.parse_args() if args.enable_type_reduction and args.format == "ONNX": print("Type reduction requires model format to be ORT.", file=sys.stderr) sys.exit(-1) model_path_or_dir = args.model_path_or_dir.resolve() if args.config_path: config_path = args.config_path.resolve() else: config_path = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent if config_path.is_dir(): filename = "required_operators_and_types.config" if args.enable_type_reduction else "required_operators.config" config_path = config_path.joinpath(filename) if args.format == "ONNX": model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".onnx")) create_config_from_onnx_models(model_files, config_path) else: from util.ort_format_model import create_config_from_models as create_config_from_ort_models model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".ort")) create_config_from_ort_models(model_files, config_path, args.enable_type_reduction) # Debug code to validate that the config parsing matches # from util import parse_config # required_ops, op_type_usage_processor, _ = parse_config(args.config_path, True) # op_type_usage_processor.debug_dump() if __name__ == "__main__": main()