onnxruntime/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py
Scott McKay 5e0928a777
Enable running PEP8 on python scripts using flake8 (#3928)
* Enable running PEP8 checks via flake8 as part of the build if flake8 is installed.
Update scripts in \tools and \onnxruntime\python. Excluding \onnxruntime\python\tools which needs a lot more work to be PEP8 compliant. Also excluding orttraining\tools for the same reason.
Install flake8 as part of the static_analysis build task in the Win-CPU CI so the checks are run in one CI build.
Update coding standards doc.
2020-05-15 07:15:06 +10:00

132 lines
5.3 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import glob
import logging
import os
import re
# Configure logging once at import time: messages are printed as "[LEVEL] - message" and
# everything from DEBUG upwards is shown.
logging.basicConfig(format="[%(levelname)s] - %(message)s", level=logging.DEBUG)
# Logger used by the functions in this script to report findings and problems.
# getLogger() is called without a name here.
log = logging.getLogger()
def parse_args():
    """Parse the command line options for this script.

    The only option is --ort-root / -o (required): the root directory of the ONNX Runtime
    repository to search.

    :return: The argparse namespace; its ort_root attribute is an existing directory.
    :raises argparse.ArgumentError: If the supplied --ort-root value is not a directory.
    """
    arg_parser = argparse.ArgumentParser(
        description='Find optimizers that involve operators which may need an update to the supported opset versions.')

    # Keep the action object so a validation failure can be attributed to this option.
    ort_root_action = arg_parser.add_argument(
        '--ort-root', '-o', required=True, type=str,
        help='The root directory of the ONNX Runtime repository to search.')

    parsed = arg_parser.parse_args()
    if os.path.isdir(parsed.ort_root):
        return parsed

    message = "{} is not a valid directory".format(parsed.ort_root)
    raise argparse.ArgumentError(ort_root_action, message)
def get_call_args_from_file(filename, function_or_declaration):
    """Search a file for all function calls or declarations that match the provided name.

    Currently requires both the opening '(' and closing ')' to be on the same line as the name.

    :param filename: Path of the text file to scan.
    :param function_or_declaration: Name to look for. It is passed to re.finditer, so it is
        interpreted as a regular expression pattern.
    :return: List with one string per complete match: the text between the first '(' and the
        first ')' found after the matched name on that line. Matches that are missing either
        bracket on the line are logged as errors (with a 1-based line number) and skipped.
    """
    results = []
    with open(filename) as f:
        # Number lines from 1 so the reported line matches what an editor displays.
        for line_num, line in enumerate(f, start=1):
            for match in re.finditer(function_or_declaration, line):
                # check we have both the opening and closing brackets for the function call/declaration.
                # if we do we have all the arguments
                start = line.find('(', match.end())
                end = line.find(')', match.end())

                if start == -1 or end == -1:
                    # TODO: handle automatically by merging lines
                    log.error("Call/Declaration is split over multiple lines. Please check manually. "
                              "File:{} Line:{}".format(filename, line_num))
                    continue

                results.append(line[start + 1: end])

    return results
def get_latest_op_versions(root_dir):
    """Build a map of 'domain.op' -> opset string from the CPU kernel registration files.

    Both registration files are scanned in order, and within each file every kernel macro is
    handled in turn. Each entry overwrites any earlier entry for the same key, so the value
    returned for an operator is the one from the last matching line that was processed.

    :param root_dir: Root directory of the ONNX Runtime repository.
    :return: Dict keyed by '<domain>.<op>' whose values are the opset as written in the source.
    """
    registration_files = (
        "onnxruntime/core/providers/cpu/cpu_execution_provider.cc",
        "onnxruntime/contrib_ops/cpu_contrib_kernels.cc",
    )

    # (macro name, position of the operator name in the macro's argument list)
    # e.g. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Clip);
    # e.g. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ArgMax);
    kernel_macros = (
        ('ONNX_OPERATOR_KERNEL_CLASS_NAME', 3),
        ('ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME', 4),
    )

    latest_by_op = {}
    for relative_path in registration_files:
        source_path = os.path.join(root_dir, relative_path)
        for macro_name, op_position in kernel_macros:
            for call in get_call_args_from_file(source_path, macro_name):
                fields = [field.strip() for field in call.split(',')]
                # fields[1] is the domain and fields[2] is the opset for both macros.
                latest_by_op[fields[1] + '.' + fields[op_position]] = fields[2]

    return latest_by_op
def find_potential_issues(root_dir, op_to_opset):
    """Report optimizer checks whose last listed opset differs from the latest known opset.

    Every .cc and .h file under onnxruntime/core/optimizer is searched for calls to
    graph_utils::IsSupportedOptypeVersionAndDomain. For each call the last version in the
    '{...}' list is compared with the entry for that operator in op_to_opset and a warning is
    logged when they differ. Calls that use a non-literal operator name, or operators missing
    from op_to_opset, are logged as errors.

    :param root_dir: Root directory of the ONNX Runtime repository.
    :param op_to_opset: Dict of '<domain>.<op>' -> opset, as built by get_latest_op_versions.
    """
    optimizer_dir = os.path.join(root_dir, "onnxruntime/core/optimizer")
    sources = []
    for pattern in ('/**/*.cc', '/**/*.h'):
        sources += glob.glob(optimizer_dir + pattern, recursive=True)

    for source in sources:
        for call in get_call_args_from_file(source, 'graph_utils::IsSupportedOptypeVersionAndDomain'):
            # Need to handle multiple comma separated version numbers, and the optional domain argument.
            # e.g. IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10})
            #      IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)
            # Only the first two arguments can be split off safely; the rest is parsed by hand.
            pieces = call.split(',', 2)
            op_token = pieces[1].strip()
            remainder = pieces[2]

            brace_open = remainder.find('{')
            brace_close = remainder.find('}')
            last_version = remainder[brace_open + 1: brace_close].split(',')[-1].strip()

            # A comma after the version list means an explicit domain argument follows.
            domain_comma = remainder.find(',', brace_close)
            if domain_comma == -1:
                domain = "kOnnxDomain"
            else:
                domain = remainder[domain_comma + 1:].strip()

            is_string_literal = op_token.startswith('"') and op_token.endswith('"')
            if not is_string_literal:
                log.error("Symbolic name of '{}' found for op. Please check manually. File:{}".format(op_token, source))
                continue

            qualified_op = domain + '.' + op_token[1:-1]
            if qualified_op not in op_to_opset:
                log.error("Failed to find version information for {}. File:{}".format(qualified_op, source))
                continue

            latest = op_to_opset[qualified_op]
            if int(latest) != int(last_version):
                log.warning("Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}"
                            .format(qualified_op, latest, last_version, source))
if __name__ == '__main__':
    # Script entry point: gather the latest opset per operator from the kernel registrations,
    # then check the optimizer sources against it.
    cli_args = parse_args()
    find_potential_issues(cli_args.ort_root, get_latest_op_versions(cli_args.ort_root))