Fix Where op type reduction processing (#9033)

* Update type reduction script to track Where Op's second input type.

* Clean up op_kernel_type_control.h includes.

* Use more maintainable include.
This commit is contained in:
Edward Chen 2021-09-13 08:37:58 -07:00 committed by GitHub
parent a1021a1cf4
commit 011cb8fd48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 7 deletions

View file

@ -3,16 +3,11 @@
#pragma once
#include <cstdint>
#include <tuple>
#include "boost/mp11.hpp"
#include "core/common/type_list.h"
#include "core/common/type_set_utils.h"
#include "core/framework/data_types.h"
/**
* These utilities provide a way to control what types are enabled for an Op kernel implementation.
* At a high level, we have the notion of default, required, allowed, and enabled type sets.
@ -471,5 +466,7 @@ struct EnabledTypes {
* using Dispatcher = onnxruntime::utils::MLTypeCallDispatcherFromTypeList<MyOpFirstInputEnabledTypes>;
*/
#include "core/framework/data_types.h" // for types that might be used in type specifications
// all allowed type specifications should be contained in the following file
#include "core/providers/op_kernel_type_control_overrides.inc"

View file

@ -240,6 +240,19 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
self._output_types[int(o_str)] = set(values)
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
'''
Processor for operators where the second input type is used in a typed kernel registration.
'''
def __init__(self, domain: str, optype: str):
# init with tracking of input 1 only.
super().__init__(domain, optype, inputs=[1], outputs=[])
def is_typed_registration_needed(self, type_in_registration: str,
globally_allowed_types: typing.Optional[typing.Set[str]]):
return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
'''
Processor for operators where the first output type is used in a typed kernel registration.
@ -339,8 +352,7 @@ def _create_operator_type_usage_processors():
'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin',
'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum',
'Tanh', 'TopK', 'Transpose',
'Unique',
'Where']
'Unique']
# ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
default_processor_onnx_ops_requiring_ints_for_input_0 = ['Add',
@ -392,6 +404,9 @@ def _create_operator_type_usage_processors():
onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial']
[add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops]
# Where always has a boolean first input so track the second input type for typed registration
add(Input1TypedRegistrationProcessor('ai.onnx', 'Where'))
# we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
# as that's what is used in the typed registration
add(Output0TypedRegistrationProcessor('ai.onnx', 'QuantizeLinear'))