mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
"""
|
|
Validate ORT kernel registrations.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import dataclasses
|
|
import itertools
|
|
import os
|
|
import sys
|
|
import typing
|
|
|
|
import op_registration_utils
|
|
from logger import get_logger
|
|
|
|
# Logger used for all validation output (errors, warnings and progress messages) in this script.
log = get_logger("op_registration_validator")
|
|
|
|
# Deprecated ops, keyed by "<domain>:<operator>" (the same format as RegistrationInfo.domain_and_op_str()).
# The value is the opset in which the op was deprecated. The last registration for such an op is allowed to have an
# end version instead of being unversioned, provided that end version equals (value - 1).
deprecated_ops: dict[str, int] = {
    "kOnnxDomain:Scatter": 11,
    "kOnnxDomain:Upsample": 10,
    # LayerNormalization, MeanVarianceNormalization and ThresholdedRelu were contrib ops that were incorrectly
    # registered using kOnnxDomain. They later became official ONNX operators and are registered there now. That
    # leaves entries in the contrib ops registrations whose end versions mark when the contrib op was 'deprecated'
    # and became an official op.
    "kOnnxDomain:LayerNormalization": 17,
    "kOnnxDomain:MeanVarianceNormalization": 9,
    "kOnnxDomain:ThresholdedRelu": 10,
}
|
|
|
|
|
|
@dataclasses.dataclass
class RegistrationInfo:
    """
    One kernel registration collected from a registration file.

    A registration covers the opsets [start_version, end_version]; an end_version of None means the registration is
    unversioned, i.e. it applies from start_version onwards.
    """

    domain: str  # registration domain, e.g. "kOnnxDomain"
    operator: str  # operator name, e.g. "Scatter"
    start_version: int  # first opset covered by this registration
    end_version: int | None  # last opset covered, or None for an unversioned registration
    lines: list[str]  # original source lines of the registration, echoed back when reporting errors

    def domain_and_op_str(self) -> str:
        """Return the "<domain>:<operator>" key that identifies the operator this registration belongs to."""
        return "{}:{}".format(self.domain, self.operator)
|
|
|
|
|
|
def _log_registration_error(r: RegistrationInfo, message: str):
    """
    Log an error for an invalid registration.

    The log entry names the "<domain>:<operator>" of `r`, gives the reason in `message`, and then echoes the
    registration's original source lines so the offending code is easy to locate.
    """
    op_key = r.domain_and_op_str()
    source_text = "".join(r.lines)
    log.error("Invalid registration for %s. %s\n%s", op_key, message, source_text)
|
|
|
|
|
|
class RegistrationValidator(op_registration_utils.RegistrationProcessor):
    """
    Collects kernel registrations via `process_registration()` and validates them via `ok()`.

    Registrations are grouped by (domain, operator) and sorted by version range. Within each group:
      - every registration must have start_version <= end_version (when an end version is given);
      - consecutive registrations must be contiguous (end_version == next start_version - 1), except that an
        identical version range may repeat (e.g. multiple typed registrations for the same opsets);
      - the last registration must be unversioned (end_version is None), unless the op is listed in
        `deprecated_ops` and its end version matches the deprecation, or the op has known incomplete support.
    """

    # Ops allowed to lack an unversioned registration because the CUDA EP doesn't yet support them for opset 12+.
    # TODO remove once CUDA EP supports ArgMin/ArgMax for opset 12+
    _OPS_WITH_INCOMPLETE_SUPPORT = frozenset(["kOnnxDomain:ArgMin", "kOnnxDomain:ArgMax"])

    def __init__(self):
        # every registration passed to process_registration(), in the order received
        self.all_registrations: list[RegistrationInfo] = []

    def process_registration(
        self,
        lines: list[str],
        domain: str,
        operator: str,
        start_version: int,
        end_version: int | None = None,
        type: str | None = None,
    ):
        """
        Record one registration for later validation by `ok()`.

        `type` is not stored. Typed registrations that share a version range are instead accepted as identical
        repeated ranges during validation.
        """
        self.all_registrations.append(
            RegistrationInfo(
                domain=domain, operator=operator, start_version=start_version, end_version=end_version, lines=lines
            )
        )

    def ok(self) -> bool:
        """
        Validate every recorded registration.

        Returns:
            True if all registrations are valid. Otherwise logs how many are invalid and returns False.
        """
        num_invalid_registrations = self._validate_all_registrations()
        if num_invalid_registrations > 0:
            log.error("Found %d invalid registration(s).", num_invalid_registrations)
            return False

        return True

    def _validate_all_registrations(self) -> int:
        """
        Validates all registrations added by `process_registration()` and returns the number of invalid ones.

        Registrations are sorted, grouped by (domain, operator), and each group is validated independently.
        """

        def registration_info_sort_key(r: RegistrationInfo):
            return (
                r.domain,
                r.operator,
                r.start_version,
                # unspecified end_version > specified end_version. The flag also ensures end_version is only
                # compared with `<` when both values are ints (two Nones compare equal and never reach `<`).
                1 if r.end_version is None else 0,
                r.end_version,
            )

        def domain_and_op_key(r: RegistrationInfo):
            return (r.domain, r.operator)

        sorted_registrations = sorted(self.all_registrations, key=registration_info_sort_key)

        num_invalid_registrations = 0
        # itertools.groupby only merges adjacent items, which is correct here because the sort key starts with
        # (domain, operator).
        for _, registration_group in itertools.groupby(sorted_registrations, key=domain_and_op_key):
            num_invalid_registrations += self._validate_registrations_for_domain_and_op(registration_group)

        return num_invalid_registrations

    def _validate_registrations_for_domain_and_op(self, registrations: typing.Iterator[RegistrationInfo]) -> int:
        """
        Validates registrations in sorted order for a single domain and op and returns the number of invalid ones.

        Each registration is checked against its successor; the final one has no successor (None).
        """
        num_invalid_registrations = 0
        r = next(registrations, None)
        while r is not None:
            next_r = next(registrations, None)
            if not self._validate_registration(r, next_r):
                num_invalid_registrations += 1
            r = next_r

        return num_invalid_registrations

    def _validate_registration(self, r: RegistrationInfo, next_r: RegistrationInfo | None) -> bool:
        """
        Validates a registration, `r`, with the next one in sorted order for a single domain and op, `next_r`, and
        returns whether it is valid. `next_r` is None when `r` is the last registration for its domain and op.
        """
        if not (r.end_version is None or r.start_version <= r.end_version):
            _log_registration_error(
                r, f"Start version ({r.start_version}) is greater than end version ({r.end_version})."
            )
            return False

        if next_r is None:
            return self._validate_last_registration(r)

        # It is valid to match next registration start and end versions exactly.
        # This is expected if there are multiple registrations for an opset (e.g., typed registrations).
        if (r.start_version, r.end_version) == (next_r.start_version, next_r.end_version):
            return True

        # This registration has no end version but it should have one if the next registration has different versions.
        if r.end_version is None:
            _log_registration_error(
                r,
                f"Registration for opset {r.start_version} has no end version but was superseded by version "
                f"{next_r.start_version}.",
            )
            return False

        # This registration's end version is not adjacent to the start version of the next registration
        # (this catches both gaps and overlaps between version ranges).
        if r.end_version != next_r.start_version - 1:
            _log_registration_error(
                r,
                f"Registration end version is not adjacent to the next registration's start version. "
                f"Current start and end versions: {(r.start_version, r.end_version)}. "
                f"Next start and end versions: {(next_r.start_version, next_r.end_version)}.",
            )
            return False

        return True

    def _validate_last_registration(self, last_r: RegistrationInfo) -> bool:
        """
        Validates the last registration in sorted order for a single domain and op and returns whether it is valid.

        The last registration must be unversioned, unless the op is deprecated and the registration ends exactly
        one opset before the deprecation, or the op is in `_OPS_WITH_INCOMPLETE_SUPPORT`.
        """
        # TODO If the operator is deprecated, validation is more lax. I.e., it doesn't require a versioned registration.
        # This could be tightened up but we would need to handle the deprecated contrib ops registered in the ONNX
        # domain that have newer registrations in a non-contrib op file differently. They should only be considered
        # deprecated as contrib ops.
        domain_and_op_str = last_r.domain_and_op_str()
        deprecation_version = deprecated_ops.get(domain_and_op_str)

        allow_missing_unversioned_registration = (
            deprecation_version is not None and last_r.end_version == deprecation_version - 1
        )

        # special handling for ArgMin/ArgMax, which CUDA EP doesn't yet support for opset 12+
        if domain_and_op_str in self._OPS_WITH_INCOMPLETE_SUPPORT:
            log.warning(
                "Allowing missing unversioned registration for op with incomplete support: %s.", domain_and_op_str
            )
            allow_missing_unversioned_registration = True

        if last_r.end_version is not None and not allow_missing_unversioned_registration:
            if deprecation_version is not None:
                # A deprecated op may end with a versioned registration, but only at the opset before deprecation.
                # Report the actual mismatch rather than a generic "missing unversioned registration".
                log.error(
                    "Last registration for deprecated op %s has end version %d but expected %d "
                    "(deprecated in opset %d).",
                    domain_and_op_str,
                    last_r.end_version,
                    deprecation_version - 1,
                    deprecation_version,
                )
            else:
                log.error("Missing unversioned registration for %s.", domain_and_op_str)
            return False

        return True
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Script to validate operator kernel registrations.")
    parser.add_argument(
        "--ort_root",
        type=str,
        help="Path to ONNXRuntime repository root. Inferred from the location of this script if not provided.",
    )
    args = parser.parse_args()

    # An explicit root is made absolute; passing None leaves it to the helper to infer the root.
    ort_root = None
    if args.ort_root:
        ort_root = os.path.abspath(args.ort_root)

    # Validate both the CPU and the CUDA EP kernel registrations.
    include_cuda = True
    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, include_cuda)

    def validate_registration_file(file: str) -> bool:
        """Validate the registrations in a single file with a fresh validator and return whether they are valid."""
        log.info(f"Processing {file}")
        validator = RegistrationValidator()
        op_registration_utils.process_kernel_registration_file(file, validator)
        return validator.ok()

    # Collect every per-file result before reducing, so each file is validated and reports its own errors instead
    # of all() short-circuiting at the first invalid file.
    file_results = [validate_registration_file(registration_file) for registration_file in registration_files]
    validation_successful = all(file_results)

    outcome = "succeeded" if validation_successful else "failed"
    log.info(f"Op kernel registration validation {outcome}.")
    sys.exit(0 if validation_successful else 1)
|