mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
### Description The Windows GPU Reduced Ops CI pipeline is broken because registered kernels can now have a second template type, which the Python code that checks the kernel registrations does not handle. This PR addresses the issue on the Python side by keeping a single type that is the concatenation of the two types.
221 lines
8.7 KiB
Python
221 lines
8.7 KiB
Python
# !/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
"""
|
|
Validate ORT kernel registrations.
|
|
"""
|
|
|
|
import argparse
|
|
import dataclasses
|
|
import itertools
|
|
import os
|
|
import sys
|
|
import typing
|
|
|
|
import op_registration_utils
|
|
from logger import get_logger
|
|
|
|
# Module-level logger used by the validation code in this file.
log = get_logger("op_registration_validator")

# deprecated ops where the last registration should have an end version.
# Keys have the form "<domain>:<operator>", which is the format produced by RegistrationInfo.domain_and_op_str().
# value for each entry is the opset when it was deprecated. end version of last registration should equal value - 1.
deprecated_ops: typing.Dict[str, int] = {
    "kOnnxDomain:Scatter": 11,
    "kOnnxDomain:Upsample": 10,
    # LayerNormalization, MeanVarianceNormalization and ThresholdedRelu were in contrib ops and incorrectly registered
    # using the kOnnxDomain. They became official ONNX operators later and are registered there now. That leaves
    # entries in the contrib ops registrations with end versions for when the contrib op was 'deprecated'
    # and became an official op.
    "kOnnxDomain:LayerNormalization": 17,
    "kOnnxDomain:MeanVarianceNormalization": 9,
    "kOnnxDomain:ThresholdedRelu": 10,
}
|
|
|
|
|
|
@dataclasses.dataclass
class RegistrationInfo:
    """
    Plain record describing a single kernel registration.
    """

    # Domain identifier of the registered operator, e.g. "kOnnxDomain".
    domain: str
    # Name of the registered operator.
    operator: str
    # First opset version covered by the registration.
    start_version: int
    # Last opset version covered by the registration, or None when the registration has no end version.
    end_version: typing.Optional[int]
    # Source lines of the registration; they are joined and echoed in validation error messages.
    lines: typing.List[str]

    def domain_and_op_str(self):
        """
        Returns the "<domain>:<operator>" identifier for this registration.
        """
        domain, operator = self.domain, self.operator
        return f"{domain}:{operator}"
|
|
|
|
|
|
def _log_registration_error(r: RegistrationInfo, message: str):
    """
    Logs an error for registration `r`: its domain and operator, then `message`, then the registration's
    source lines on the following line(s).
    """
    # The stored lines are concatenated as-is; no separator is inserted between them.
    registration_source = "".join(r.lines)
    log.error(f"Invalid registration for {r.domain_and_op_str()}. {message}\n{registration_source}")
|
|
|
|
|
|
class RegistrationValidator(op_registration_utils.RegistrationProcessor):
    """
    Registration processor that records every kernel registration passed to `process_registration()` and checks
    that the opset version ranges registered for each domain and operator are consistent.

    A registration is counted as invalid when:
    - its start version is greater than its end version,
    - it has no end version although a later registration with different versions exists,
    - its end version is not immediately before the start version of the next registration with different versions,
    - it is the last registration for its domain and operator and has an end version, unless the operator is
      listed in `deprecated_ops` with a matching deprecation opset or is an op with known incomplete support.
    """

    # Ops whose last registration is allowed to have an end version.
    # special handling for ArgMin/ArgMax, which CUDA EP doesn't yet support for opset 12+
    # TODO remove once CUDA EP supports ArgMin/ArgMax for opset 12+
    _OPS_WITH_INCOMPLETE_SUPPORT = ("kOnnxDomain:ArgMin", "kOnnxDomain:ArgMax")

    def __init__(self):
        # Every registration seen so far, in the order `process_registration()` was called.
        self.all_registrations: typing.List[RegistrationInfo] = []

    def process_registration(
        self,
        lines: typing.List[str],
        domain: str,
        operator: str,
        start_version: int,
        end_version: typing.Optional[int] = None,
        type: typing.Optional[str] = None,
    ):
        """
        Records a registration for later validation.

        `type` is accepted but not used by the validation (presumably it is part of the base class signature).
        """
        self.all_registrations.append(
            RegistrationInfo(
                domain=domain, operator=operator, start_version=start_version, end_version=end_version, lines=lines
            )
        )

    def ok(self):
        """
        Validates all recorded registrations and returns True if none of them is invalid.
        The number of invalid registrations is logged as an error when there are any.
        """
        num_invalid_registrations = self._validate_all_registrations()
        if num_invalid_registrations > 0:
            log.error(f"Found {num_invalid_registrations} invalid registration(s).")
            return False

        return True

    def _validate_all_registrations(self) -> int:
        """
        Validates all registrations added by `process_registration()` and returns the number of invalid ones.
        """

        def registration_info_sort_key(r: RegistrationInfo):
            return (
                r.domain,
                r.operator,
                r.start_version,
                1 if r.end_version is None else 0,  # unspecified end_version > specified end_version
                # Only order-compared when the flag above is equal for both keys, i.e. when both end versions are
                # ints. Two None values compare equal, so None is never ordered against an int.
                r.end_version,
            )

        def domain_and_op_key(r: RegistrationInfo):
            return (r.domain, r.operator)

        # groupby() only merges adjacent items, so the registrations must be sorted by domain and operator first.
        sorted_registrations = sorted(self.all_registrations, key=registration_info_sort_key)

        return sum(
            self._validate_registrations_for_domain_and_op(registration_group)
            for _, registration_group in itertools.groupby(sorted_registrations, key=domain_and_op_key)
        )

    def _validate_registrations_for_domain_and_op(self, registrations: typing.Iterator[RegistrationInfo]) -> int:
        """
        Validates registrations in sorted order for a single domain and op and returns the number of invalid ones.
        """
        # iter() is a no-op for iterators and additionally allows any other iterable (e.g. a list) to be passed.
        registration_iter = iter(registrations)

        num_invalid_registrations = 0
        r = next(registration_iter, None)
        while r is not None:
            # Each registration is validated together with its successor; the successor is None for the last one.
            next_r = next(registration_iter, None)
            if not self._validate_registration(r, next_r):
                num_invalid_registrations += 1
            r = next_r

        return num_invalid_registrations

    def _validate_registration(self, r: RegistrationInfo, next_r: typing.Optional[RegistrationInfo]) -> bool:
        """
        Validates a registration, `r`, with the next one in sorted order for a single domain and op, `next_r`, and
        returns whether it is valid. `next_r` is None when `r` is the last registration.
        """
        if not (r.end_version is None or r.start_version <= r.end_version):
            _log_registration_error(
                r, f"Start version ({r.start_version}) is greater than end version ({r.end_version})."
            )
            return False

        if next_r is None:
            return self._validate_last_registration(r)

        # It is valid to match next registration start and end versions exactly.
        # This is expected if there are multiple registrations for an opset (e.g., typed registrations).
        if (r.start_version, r.end_version) == (next_r.start_version, next_r.end_version):
            return True

        # This registration has no end version but it should have one if the next registration has different versions.
        if r.end_version is None:
            _log_registration_error(
                r,
                f"Registration for opset {r.start_version} has no end version but was superseded by version "
                f"{next_r.start_version}.",
            )
            return False

        # This registration's end version is not adjacent to the start version of the next registration.
        if r.end_version != next_r.start_version - 1:
            _log_registration_error(
                r,
                "Registration end version is not adjacent to the next registration's start version. "
                f"Current start and end versions: {(r.start_version, r.end_version)}. "
                f"Next start and end versions: {(next_r.start_version, next_r.end_version)}.",
            )
            return False

        return True

    def _validate_last_registration(self, last_r: RegistrationInfo) -> bool:
        """
        Validates the last registration in sorted order for a single domain and op and returns whether it is valid.
        """
        # make sure we have an unversioned last entry for each operator unless it's deprecated
        if last_r.end_version is None:
            return True

        # TODO If the operator is deprecated, validation is more lax. I.e., it doesn't require a versioned registration.
        # This could be tightened up but we would need to handle the deprecated contrib ops registered in the ONNX
        # domain that have newer registrations in a non-contrib op file differently. They should only be considered
        # deprecated as contrib ops.
        domain_and_op_str = last_r.domain_and_op_str()
        deprecation_version = deprecated_ops.get(domain_and_op_str, None)
        if deprecation_version is not None and last_r.end_version == deprecation_version - 1:
            return True

        # Warn only when this exemption is the reason a missing unversioned registration is accepted, so that
        # registrations which need no exemption do not produce a misleading warning.
        if domain_and_op_str in self._OPS_WITH_INCOMPLETE_SUPPORT:
            log.warning(
                f"Allowing missing unversioned registration for op with incomplete support: {domain_and_op_str}."
            )
            return True

        log.error(f"Missing unversioned registration for {domain_and_op_str}.")
        return False
|
|
|
|
|
|
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Script to validate operator kernel registrations.")
    arg_parser.add_argument(
        "--ort_root",
        type=str,
        help="Path to ONNXRuntime repository root. Inferred from the location of this script if not provided.",
    )
    cli_args = arg_parser.parse_args()

    # A missing or empty --ort_root is forwarded as None; per the help text the root is then inferred
    # (presumably by op_registration_utils - confirm there).
    if cli_args.ort_root:
        ort_root = os.path.abspath(cli_args.ort_root)
    else:
        ort_root = None

    include_cuda = True  # validate CPU and CUDA EP registrations
    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, include_cuda)

    def validate_registration_file(file: str) -> bool:
        """
        Processes one kernel registration file with a fresh validator and returns whether it is valid.
        """
        log.info(f"Processing {file}")

        validator = RegistrationValidator()
        op_registration_utils.process_kernel_registration_file(file, validator)

        return validator.ok()

    # Validate each file first by storing the validation results in a list.
    # Otherwise, all() will exit early when it encounters the first invalid file.
    per_file_results = [validate_registration_file(path) for path in registration_files]
    validation_successful = all(per_file_results)

    outcome = "succeeded" if validation_successful else "failed"
    log.info(f"Op kernel registration validation {outcome}.")
    sys.exit(0 if validation_successful else 1)
|