# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import pathlib import typing from ..logger import get_logger from .operator_type_usage_processors import OperatorTypeUsageManager from .ort_model_processor import OrtFormatModelProcessor log = get_logger("ort_format_model.utils") def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib.Path], enable_type_reduction: bool): required_ops = {} op_type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None for model_file in model_files: if not model_file.is_file(): raise ValueError(f"Path is not a file: '{model_file}'") model_processor = OrtFormatModelProcessor(str(model_file), required_ops, op_type_usage_manager) model_processor.process() # this updates required_ops and op_type_processors return required_ops, op_type_usage_manager def create_config_from_models( model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, enable_type_reduction: bool ): """ Create a configuration file with required operators and optionally required types. :param model_files: Model files to use to generate the configuration file. :param output_file: File to write configuration to. :param enable_type_reduction: Include required type information for individual operators in the configuration. """ required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction) output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w") as out: out.write("# Generated from model/s:\n") for model_file in sorted(model_files): out.write(f"# - {model_file}\n") for domain in sorted(required_ops.keys()): for opset in sorted(required_ops[domain].keys()): ops = required_ops[domain][opset] if ops: out.write(f"{domain};{opset};") if enable_type_reduction: # type string is empty if op hasn't been seen entries = [ "{}{}".format(op, op_type_processors.get_config_entry(domain, op) or "") for op in sorted(ops) ] else: entries = sorted(ops) out.write("{}\n".format(",".join(entries))) log.info("Created config in %s", output_file)