From 4d8510611bfc679a714a06fd35aea1dd23f8baef Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 29 Sep 2022 16:55:22 +1000 Subject: [PATCH] Update find_optimizer_opset_version_updates_required.py to use the ONNX headers to determine the latest opset. (#12484) **Description**: Use the onnx headers to find the latest opset for each operator. This allows the script to detect optimizers with `graph_utils::IsSupportedOptypeVersionAndDomain` calls that need updating when run during the update of the onnx commit id. Without this change issues are not detected until a new kernel is registered. **Motivation and Context** Detect optimizers that need updates as part of the ONNX update process. --- ...ptimizer_opset_version_updates_required.py | 51 +++++++++++++++++-- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/tools/python/find_optimizer_opset_version_updates_required.py b/tools/python/find_optimizer_opset_version_updates_required.py index 412b25df90..13438eef5a 100644 --- a/tools/python/find_optimizer_opset_version_updates_required.py +++ b/tools/python/find_optimizer_opset_version_updates_required.py @@ -96,13 +96,22 @@ def get_multiline_call_args_from_file(filename: str, function_or_declaration: st return results -def get_latest_op_versions(root_dir): +def _add_if_newer(domain: str, op: str, opset: int, op_to_opset: typing.Dict[str, int]): + key = domain + "." + op + if key not in op_to_opset or op_to_opset[key] < opset: + op_to_opset[key] = opset + + +def get_latest_ort_op_versions(root_dir): """Find the entries for the latest opset for each operator.""" op_to_opset = {} files = [ - os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"), + # for ONNX operators we use get_latest_onnx_op_versions + # os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"), + # for internal kernels we use the current registrations os.path.join(root_dir, "onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc"), + os.path.join(root_dir, "onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc"), ] for file in files: @@ -113,7 +122,7 @@ def get_latest_op_versions(root_dir): domain = args[1].strip() opset = args[2].strip() op = args[3].strip() - op_to_opset[domain + "." + op] = opset + _add_if_newer(domain, op, int(opset), op_to_opset) # e.g. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ArgMax); calls = get_multiline_call_args_from_file(file, "ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME") @@ -122,7 +131,34 @@ def get_latest_op_versions(root_dir): domain = args[1].strip() opset = args[2].strip() op = args[4].strip() - op_to_opset[domain + "." + op] = opset + _add_if_newer(domain, op, int(opset), op_to_opset) + + return op_to_opset + + +def get_latest_onnx_op_versions(root_dir): + """Get the latest versions of the ONNX operators from the ONNX headers.""" + + op_to_opset = {} + files = [ + # operators with domain of 'Onnx' + os.path.join(root_dir, "cmake/external/onnx/onnx/defs/operator_sets.h"), + # ML operators with domain of 'OnnxML' + os.path.join(root_dir, "cmake/external/onnx/onnx/defs/operator_sets_ml.h"), + ] + + for file in files: + # e.g. fn(GetOpSchema()); + # fn(GetOpSchema()); + calls = get_multiline_call_args_from_file(file, "ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME") + for call in calls: + args = call.split(",") + orig_domain = args[0].strip() + # convert domain to the ORT constants + domain = "kMLDomain" if orig_domain == "OnnxML" else "kOnnxDomain" + opset = args[1].strip() + op = args[2].strip() + _add_if_newer(domain, op, int(opset), op_to_opset) return op_to_opset @@ -174,5 +210,10 @@ def find_potential_issues(root_dir, op_to_opset): if __name__ == "__main__": arguments = parse_args() - op_to_opset_map = get_latest_op_versions(arguments.ort_root) + ort_to_opset_map = get_latest_ort_op_versions(arguments.ort_root) + onnx_op_to_opset_map = get_latest_onnx_op_versions(arguments.ort_root) + + # merge the two maps + op_to_opset_map = {**ort_to_opset_map, **onnx_op_to_opset_map} + find_potential_issues(arguments.ort_root, op_to_opset_map)