onnxruntime/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py

142 lines
5.4 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import glob
import logging
import os
import re
logging.basicConfig(format="[%(levelname)s] - %(message)s", level=logging.DEBUG)
log = logging.getLogger()
def parse_args():
parser = argparse.ArgumentParser(
description="Find optimizers that involve operators which may need an update to the supported opset versions."
)
root_arg = parser.add_argument(
"--ort-root", "-o", required=True, type=str, help="The root directory of the ONNX Runtime repository to search."
)
args = parser.parse_args()
if not os.path.isdir(args.ort_root):
raise argparse.ArgumentError(root_arg, "{} is not a valid directory".format(args.ort_root))
return args
def get_call_args_from_file(filename, function_or_declaration):
"""Search a file for all function calls or declarations that match the provided name.
Currently requires both the opening '(' and closing ')' to be on the same line."""
results = []
with open(filename) as f:
line_num = 0
for line in f.readlines():
for match in re.finditer(function_or_declaration, line):
# check we have both the opening and closing brackets for the function call/declaration.
# if we do we have all the arguments
start = line.find("(", match.end())
end = line.find(")", match.end())
have_all_args = start != -1 and end != -1
if have_all_args:
results.append(line[start + 1 : end])
else:
# TODO: handle automatically by merging lines
log.error(
"Call/Declaration is split over multiple lines. Please check manually."
"File:{} Line:{}".format(filename, line_num)
)
continue
line_num += 1
return results
def get_latest_op_versions(root_dir):
"""Find the entries for the latest opset for each operator."""
op_to_opset = {}
files = [
os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"),
os.path.join(root_dir, "onnxruntime/contrib_ops/cpu_contrib_kernels.cc"),
]
for file in files:
# e.g. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Clip);
calls = get_call_args_from_file(file, "ONNX_OPERATOR_KERNEL_CLASS_NAME")
for call in calls:
args = call.split(",")
domain = args[1].strip()
opset = args[2].strip()
op = args[3].strip()
op_to_opset[domain + "." + op] = opset
# e.g. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ArgMax);
calls = get_call_args_from_file(file, "ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME")
for call in calls:
args = call.split(",")
domain = args[1].strip()
opset = args[2].strip()
op = args[4].strip()
op_to_opset[domain + "." + op] = opset
return op_to_opset
def find_potential_issues(root_dir, op_to_opset):
optimizer_dir = os.path.join(root_dir, "onnxruntime/core/optimizer")
files = glob.glob(optimizer_dir + "/**/*.cc", recursive=True)
files += glob.glob(optimizer_dir + "/**/*.h", recursive=True)
for file in files:
calls = get_call_args_from_file(file, "graph_utils::IsSupportedOptypeVersionAndDomain")
for call in calls:
# Need to handle multiple comma separated version numbers, and the optional domain argument.
# e.g. IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10})
# IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)
args = call.split(",", 2) # first 2 args are simple, remainder need custom processing
op = args[1].strip()
versions_and_domain_arg = args[2]
v1 = versions_and_domain_arg.find("{")
v2 = versions_and_domain_arg.find("}")
versions = versions_and_domain_arg[v1 + 1 : v2].split(",")
last_version = versions[-1].strip()
domain_arg_start = versions_and_domain_arg.find(",", v2)
if domain_arg_start != -1:
domain = versions_and_domain_arg[domain_arg_start + 1 :].strip()
else:
domain = "kOnnxDomain"
if op.startswith('"') and op.endswith('"'):
op = domain + "." + op[1:-1]
else:
log.error("Symbolic name of '{}' found for op. Please check manually. File:{}".format(op, file))
continue
if op in op_to_opset:
latest = op_to_opset[op]
if int(latest) != int(last_version):
log.warning(
"Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}".format(
op, latest, last_version, file
)
)
else:
log.error("Failed to find version information for {}. File:{}".format(op, file))
if __name__ == "__main__":
arguments = parse_args()
op_to_opset_map = get_latest_op_versions(arguments.ort_root)
find_potential_issues(arguments.ort_root, op_to_opset_map)