#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Find ONNX Runtime optimizers whose supported opset lists may lag behind the latest registered CPU kernels."""

import argparse
import glob
import logging
import os
import re

logging.basicConfig(format="[%(levelname)s] - %(message)s", level=logging.DEBUG)
log = logging.getLogger()

# Kernel registration macros to scan, paired with the index of the op name in the macro's argument list.
_KERNEL_MACROS = (
    # e.g. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Clip);
    ('ONNX_OPERATOR_KERNEL_CLASS_NAME', 3),
    # e.g. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ArgMax);
    ('ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME', 4),
)


def parse_args():
    """Parse and validate the command line.

    Returns:
        argparse.Namespace with an ``ort_root`` attribute naming an existing directory.

    Exits with a usage message (via ``parser.error``) if ``--ort-root`` is not a directory.
    """
    parser = argparse.ArgumentParser(
        description='Find optimizers that involve operators which may need an update to the supported opset versions.')
    parser.add_argument('--ort-root', '-o', required=True, type=str,
                        help='The root directory of the ONNX Runtime repository to search.')
    args = parser.parse_args()

    if not os.path.isdir(args.ort_root):
        # parse_args() has already returned, so raising ArgumentError here would surface as an unhandled
        # traceback. parser.error prints the usage line plus the message and exits cleanly.
        parser.error("argument --ort-root/-o: {} is not a valid directory".format(args.ort_root))

    return args


def _find_closing_paren(text, open_idx):
    """Return the index of the ')' that balances the '(' at ``open_idx`` in ``text``, or -1 if absent."""
    depth = 0
    for idx in range(open_idx, len(text)):
        char = text[idx]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return idx
    return -1


def get_call_args_from_file(filename, function_or_declaration):
    """Search a file for all function calls or declarations that match the provided name.

    ``function_or_declaration`` is used as a regular expression. For every match, the text between the
    first '(' after the match and its balanced closing ')' is collected as one raw argument string, so
    nested calls such as ``f(a, g(b))`` are kept whole.

    Currently requires both the opening '(' and the matching closing ')' to be on the same line.
    Occurrences that do not satisfy this are reported with ``log.error`` (1-based line number) and skipped.

    Returns:
        List of raw argument strings, in file order.
    """
    pattern = re.compile(function_or_declaration)
    results = []
    # errors='replace' so a stray non-UTF-8 byte in a source file cannot abort the scan; only ASCII
    # identifiers and punctuation matter for the matching below.
    with open(filename, encoding='utf-8', errors='replace') as f:
        for line_num, line in enumerate(f, start=1):
            for match in pattern.finditer(line):
                start = line.find('(', match.end())
                end = _find_closing_paren(line, start) if start != -1 else -1
                if end != -1:
                    results.append(line[start + 1:end])
                else:
                    # TODO: handle automatically by merging lines
                    log.error("Call/Declaration is split over multiple lines. Please check manually. "
                              "File:{} Line:{}".format(filename, line_num))

    return results


def _is_newer(candidate, current):
    """Return True if opset string ``candidate`` should replace ``current`` (None means nothing recorded yet)."""
    if current is None:
        return True
    try:
        return int(candidate) > int(current)
    except ValueError:
        # Non-numeric opset (e.g. a symbolic constant): fall back to last-one-wins.
        return True


def get_latest_op_versions(root_dir):
    """Find the entries for the latest opset for each operator.

    Scans the CPU execution provider and CPU contrib kernel registration files under ``root_dir``.

    Returns:
        Dict mapping '<domain>.<op>' (e.g. 'kOnnxDomain.Clip') to the highest opset string registered for it.
    """
    op_to_opset = {}
    files = [os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"),
             os.path.join(root_dir, "onnxruntime/contrib_ops/cpu_contrib_kernels.cc")]

    for filename in files:
        for macro, op_index in _KERNEL_MACROS:
            for call in get_call_args_from_file(filename, macro):
                args = [arg.strip() for arg in call.split(',')]
                if len(args) <= op_index:
                    log.error("Unexpected arguments '{}' for {}. Please check manually. File:{}"
                              .format(call, macro, filename))
                    continue

                domain, opset, op = args[1], args[2], args[op_index]
                key = domain + '.' + op
                # Keep the highest opset rather than whichever registration happens to appear last.
                if _is_newer(opset, op_to_opset.get(key)):
                    op_to_opset[key] = opset

    return op_to_opset


def find_potential_issues(root_dir, op_to_opset):
    """Log optimizer opset checks that do not match the latest registered kernel opset.

    Scans every .cc/.h file under onnxruntime/core/optimizer for calls to
    graph_utils::IsSupportedOptypeVersionAndDomain. For each call, the last listed version is compared
    with ``op_to_opset`` (as returned by get_latest_op_versions). Findings are logged; nothing is returned.
    Calls that cannot be parsed are logged as errors and skipped, so one odd call does not stop the scan.
    """
    optimizer_dir = os.path.join(root_dir, "onnxruntime/core/optimizer")
    files = sorted(glob.glob(os.path.join(optimizer_dir, '**', '*.cc'), recursive=True) +
                   glob.glob(os.path.join(optimizer_dir, '**', '*.h'), recursive=True))

    for filename in files:
        calls = get_call_args_from_file(filename, 'graph_utils::IsSupportedOptypeVersionAndDomain')
        for call in calls:
            # Need to handle multiple comma separated version numbers, and the optional domain argument.
            # e.g. IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10})
            #      IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)
            args = call.split(',', 2)  # first 2 args are simple, remainder need custom processing
            if len(args) < 3:
                log.error("Unexpected arguments '{}'. Please check manually. File:{}".format(call, filename))
                continue

            op_name = args[1].strip()
            if not (len(op_name) >= 2 and op_name.startswith('"') and op_name.endswith('"')):
                log.error("Symbolic name of '{}' found for op. Please check manually. File:{}"
                          .format(op_name, filename))
                continue

            versions_and_domain_arg = args[2]
            v1 = versions_and_domain_arg.find('{')
            v2 = versions_and_domain_arg.find('}', v1 + 1) if v1 != -1 else -1
            versions = [v.strip() for v in versions_and_domain_arg[v1 + 1:v2].split(',')] if v2 != -1 else []
            versions = [v for v in versions if v]  # tolerate a trailing comma inside the braces
            if not versions:
                log.error("Could not find an inline list of versions for {}. Please check manually. File:{}"
                          .format(op_name, filename))
                continue
            last_version = versions[-1]

            domain_arg_start = versions_and_domain_arg.find(',', v2)
            if domain_arg_start != -1:
                domain = versions_and_domain_arg[domain_arg_start + 1:].strip()
            else:
                domain = "kOnnxDomain"

            op = domain + '.' + op_name[1:-1]
            if op not in op_to_opset:
                log.error("Failed to find version information for {}. File:{}".format(op, filename))
                continue

            latest = op_to_opset[op]
            try:
                latest_int, last_int = int(latest), int(last_version)
            except ValueError:
                log.error("Non-numeric version for {} (latest:{} optimizer:{}). Please check manually. File:{}"
                          .format(op, latest, last_version, filename))
                continue

            if latest_int > last_int:
                log.warning("Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}"
                            .format(op, latest, last_version, filename))
            elif latest_int < last_int:
                log.warning("Optimizer support for {} ends at {}, which is newer than the latest registered "
                            "opset {}. File:{}".format(op, last_version, latest, filename))


if __name__ == '__main__':
    arguments = parse_args()
    op_to_opset_map = get_latest_op_versions(arguments.ort_root)
    find_potential_issues(arguments.ort_root, op_to_opset_map)