From 011cb8fd48b95bc952e50feef73137038d224cb0 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Mon, 13 Sep 2021 08:37:58 -0700 Subject: [PATCH] Fix Where op type reduction processing (#9033) * Update type reduction script to track Where Op's second input type. * Clean up op_kernel_type_control.h includes. * Use more maintainable include. --- .../core/providers/op_kernel_type_control.h | 7 ++----- .../operator_type_usage_processors.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/op_kernel_type_control.h b/onnxruntime/core/providers/op_kernel_type_control.h index 608d004d6c..446666eaa7 100644 --- a/onnxruntime/core/providers/op_kernel_type_control.h +++ b/onnxruntime/core/providers/op_kernel_type_control.h @@ -3,16 +3,11 @@ #pragma once -#include -#include - #include "boost/mp11.hpp" #include "core/common/type_list.h" #include "core/common/type_set_utils.h" -#include "core/framework/data_types.h" - /** * These utilities provide a way to control what types are enabled for an Op kernel implementation. * At a high level, we have the notion of default, required, allowed, and enabled type sets. @@ -471,5 +466,7 @@ struct EnabledTypes { * using Dispatcher = onnxruntime::utils::MLTypeCallDispatcherFromTypeList; */ +#include "core/framework/data_types.h" // for types that might be used in type specifications + // all allowed type specifications should be contained in the following file #include "core/providers/op_kernel_type_control_overrides.inc" diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index f3ea7aca18..2d9a679bc3 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -240,6 +240,19 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor): self._output_types[int(o_str)] = set(values) +class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor): + ''' + Processor for operators where the second input type is used in a typed kernel registration. + ''' + def __init__(self, domain: str, optype: str): + # init with tracking of input 1 only. + super().__init__(domain, optype, inputs=[1], outputs=[]) + + def is_typed_registration_needed(self, type_in_registration: str, + globally_allowed_types: typing.Optional[typing.Set[str]]): + return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types) + + class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor): ''' Processor for operators where the first output type is used in a typed kernel registration. @@ -339,8 +352,7 @@ def _create_operator_type_usage_processors(): 'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin', 'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum', 'Tanh', 'TopK', 'Transpose', - 'Unique', - 'Where'] + 'Unique'] # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available default_processor_onnx_ops_requiring_ints_for_input_0 = ['Add', @@ -392,6 +404,9 @@ def _create_operator_type_usage_processors(): onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial'] [add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops] + # Where always has a boolean first input so track the second input type for typed registration + add(Input1TypedRegistrationProcessor('ai.onnx', 'Where')) + # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type # as that's what is used in the typed registration add(Output0TypedRegistrationProcessor('ai.onnx', 'QuantizeLinear'))