From 47a0289ee6ff7c92e28ec67ba635cbbbaa1909ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 7 Jul 2023 18:21:06 +0200 Subject: [PATCH] [CI] Removes type2 in process_registration and fix Windows GPU Reduced Ops CI Pipeline (#16530) ### Description Windows GPU Reduced Ops CI Pipeline is broken due to the introduction of a second template type in registered kernels. The python code checking the registration is broken due to that. This PR addresses this issue on the python side by keeping only one type equal to the concatenation of the two types. --- tools/ci_build/op_registration_utils.py | 8 +++----- tools/ci_build/op_registration_validator.py | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py index 5ed8e52b5f..3fd01253a3 100644 --- a/tools/ci_build/op_registration_utils.py +++ b/tools/ci_build/op_registration_utils.py @@ -94,7 +94,6 @@ class RegistrationProcessor: start_version: int, end_version: typing.Optional[int] = None, type: typing.Optional[str] = None, - type2: typing.Optional[str] = None, ): """ Process lines that contain a kernel registration. @@ -103,8 +102,7 @@ class RegistrationProcessor: :param operator: Operator type :param start_version: Start version :param end_version: End version or None if unversioned registration - :param type: Type used in registration, if this is a typed registration - :param type2: Second type used in registration, if this is a typed registration + :param type: Type or types used in registration, if this is a typed registration """ pass @@ -218,7 +216,7 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",") ) registration_processor.process_registration( - lines_to_process, domain, op_type, int(start_version), None, type1, type2 + lines_to_process, domain, op_type, int(start_version), None, type1 + ", " + type2 ) elif onnx_versioned_two_typed_op in code_line: @@ -229,7 +227,7 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",") ) registration_processor.process_registration( - lines_to_process, domain, op_type, int(start_version), int(end_version), type1, type2 + lines_to_process, domain, op_type, int(start_version), int(end_version), type1 + ", " + type2 ) else: diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py index 36ca0c54e2..8222437f7b 100644 --- a/tools/ci_build/op_registration_validator.py +++ b/tools/ci_build/op_registration_validator.py @@ -60,7 +60,6 @@ class RegistrationValidator(op_registration_utils.RegistrationProcessor): start_version: int, end_version: typing.Optional[int] = None, type: typing.Optional[str] = None, - type2: typing.Optional[str] = None, ): self.all_registrations.append( RegistrationInfo(