mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
**Description**: LayerNormalization is now part of the ONNX spec as of opset 17. We had a LayerNormalization contrib op, which (incorrectly) was registered in the ONNX domain. Use that implementation for the ONNX operator. Update skip_layer_norm_fusion.cc. There are other optimizers that use LayerNormalization that need updates as well. **Motivation and Context** #12916
116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
# !/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
"""
|
|
Validate ORT kernel registrations.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import typing
|
|
|
|
import op_registration_utils
|
|
from logger import get_logger
|
|
|
|
# Logger shared by RegistrationValidator and the script entry point in this file.
log = get_logger("op_registration_validator")
|
|
|
|
# Operators whose last registration in a kernel registration file is expected to carry an end version.
# Keys are "<domain>:<operator>"; each value is the opset in which the operator was deprecated, so the end
# version of the operator's last registration must equal value - 1.
#
# LayerNormalization, MeanVarianceNormalization and ThresholdedRelu were in contrib ops and incorrectly registered
# using the kOnnxDomain. They became official ONNX operators later and are registered there now. That leaves
# entries in the contrib ops registrations with end versions for when the contrib op was 'deprecated'
# and became an official op.
deprecated_ops = {
    "kOnnxDomain:" + op_name: deprecated_in_opset
    for op_name, deprecated_in_opset in (
        ("Scatter", 11),
        ("Upsample", 10),
        ("LayerNormalization", 17),
        ("MeanVarianceNormalization", 9),
        ("ThresholdedRelu", 10),
    )
}
|
|
|
|
|
|
class RegistrationValidator(op_registration_utils.RegistrationProcessor):
    """Validate the opset version ranges of the kernel registrations in one registration file.

    process_registration() is presumably invoked by op_registration_utils.process_kernel_registration_file for
    each registration it parses. For every "<domain>:<operator>" key the most recently seen
    (start_version, end_version) pair is remembered, and each new registration must start at the opset
    immediately after the previous one ended. validate_last_registrations() then checks the final registration
    of every operator. Errors are logged and recorded, and ok() reports whether any were found.
    """

    def __init__(self):
        # "<domain>:<operator>" -> (start_version, end_version) of the most recent registration seen.
        self.last_op_registrations = {}
        # Set to True once any validation error has been logged.
        self.failed = False

    def process_registration(
        self,
        lines: typing.List[str],
        domain: str,
        operator: str,
        start_version: int,
        end_version: typing.Optional[int] = None,
        type: typing.Optional[str] = None,
    ):
        """Check a registration against the previous registration of the same operator.

        :param lines: Source lines of the registration. Not used by this validator.
        :param domain: Operator domain, e.g. "kOnnxDomain".
        :param operator: Operator name.
        :param start_version: First opset covered by the registration.
        :param end_version: Last opset covered by the registration, or None if it is unversioned (open ended).
        :param type: Type of a typed registration, if any. Not used by this validator.
        """
        key = domain + ":" + operator
        prev_start, prev_end = self.last_op_registrations.get(key, (None, None))

        # Always remember this registration, including when it is invalid. Otherwise later registrations (and the
        # final check in validate_last_registrations) would be compared against a stale entry, turning one bad
        # registration into a cascade of spurious errors.
        self.last_op_registrations[key] = (start_version, end_version)

        if prev_start is None:
            # first registration seen for this operator; nothing to compare against
            return

        # a typed registration where the to/from matches for each entry so nothing to check
        if prev_start == start_version and prev_end == end_version:
            return

        if prev_end is None:
            # previous registration was unversioned but should have been if we are seeing another registration
            log.error(
                "Invalid registration for {}. Registration for opset {} has no end version but was "
                "superceeded by version {}.".format(key, prev_start, start_version)
            )
            self.failed = True
        elif prev_end != start_version - 1:
            # previous registration end opset is not adjacent to the start of the next registration
            log.error(
                "Invalid registration for {}. Registration for opset {} should have end version of {}".format(
                    key, prev_start, start_version - 1
                )
            )
            self.failed = True

    def ok(self):
        """Return True if no validation errors have been found."""
        return not self.failed

    def validate_last_registrations(self):
        """Check the final registration recorded for each operator.

        The final registration must be unversioned (no end version). Operators listed in deprecated_ops may instead
        have a final registration that ends at the opset they were deprecated in minus one.
        """
        for key, (_, opset_to) in self.last_op_registrations.items():
            if opset_to is None:
                # Unversioned last registration. This is accepted for deprecated_ops entries too, presumably because
                # the same key can be registered as an official op in a different registration file.
                continue

            deprecated_in = deprecated_ops.get(key)
            if deprecated_in is None:
                log.error("Missing unversioned registration for {}".format(key))
                self.failed = True
            elif opset_to != deprecated_in - 1:
                # an end version is expected for a deprecated op, but it must match the deprecation opset
                log.error(
                    "Invalid registration for {}. Last registration should have end version of {} as the operator "
                    "was deprecated in opset {}, but has {}.".format(key, deprecated_in - 1, deprecated_in, opset_to)
                )
                self.failed = True
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Script to validate operator kernel registrations.")

    parser.add_argument(
        "--ort_root",
        type=str,
        help="Path to ONNXRuntime repository root. Inferred from the location of this script if not provided.",
    )

    args = parser.parse_args()

    # An empty root is passed through when --ort_root is not given (see the help text above).
    ort_root = os.path.abspath(args.ort_root) if args.ort_root else ""
    include_cuda = True  # validate CPU and CUDA EP registrations

    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, include_cuda)

    # Validate every file before exiting so that all registration errors are reported in a single run,
    # instead of stopping at the first file that has a problem.
    all_ok = True
    for file in registration_files:
        log.info("Processing {}".format(file))

        # a fresh validator per file: registrations are only compared against others in the same file
        processor = RegistrationValidator()
        op_registration_utils.process_kernel_registration_file(file, processor)
        processor.validate_last_registrations()

        if not processor.ok():
            all_ok = False

    if not all_ok:
        sys.exit(-1)
|