mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix Where op type reduction processing (#9033)
* Update type reduction script to track Where Op's second input type. * Clean up op_kernel_type_control.h includes. * Use more maintainable include.
This commit is contained in:
parent
a1021a1cf4
commit
011cb8fd48
2 changed files with 19 additions and 7 deletions
|
|
@ -3,16 +3,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/common/type_set_utils.h"
|
||||
|
||||
#include "core/framework/data_types.h"
|
||||
|
||||
/**
|
||||
* These utilities provide a way to control what types are enabled for an Op kernel implementation.
|
||||
* At a high level, we have the notion of default, required, allowed, and enabled type sets.
|
||||
|
|
@ -471,5 +466,7 @@ struct EnabledTypes {
|
|||
* using Dispatcher = onnxruntime::utils::MLTypeCallDispatcherFromTypeList<MyOpFirstInputEnabledTypes>;
|
||||
*/
|
||||
|
||||
#include "core/framework/data_types.h" // for types that might be used in type specifications
|
||||
|
||||
// all allowed type specifications should be contained in the following file
|
||||
#include "core/providers/op_kernel_type_control_overrides.inc"
|
||||
|
|
|
|||
|
|
@ -240,6 +240,19 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
|
|||
self._output_types[int(o_str)] = set(values)
|
||||
|
||||
|
||||
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
|
||||
'''
|
||||
Processor for operators where the second input type is used in a typed kernel registration.
|
||||
'''
|
||||
def __init__(self, domain: str, optype: str):
|
||||
# init with tracking of input 1 only.
|
||||
super().__init__(domain, optype, inputs=[1], outputs=[])
|
||||
|
||||
def is_typed_registration_needed(self, type_in_registration: str,
|
||||
globally_allowed_types: typing.Optional[typing.Set[str]]):
|
||||
return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
|
||||
|
||||
|
||||
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
|
||||
'''
|
||||
Processor for operators where the first output type is used in a typed kernel registration.
|
||||
|
|
@ -339,8 +352,7 @@ def _create_operator_type_usage_processors():
|
|||
'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin',
|
||||
'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum',
|
||||
'Tanh', 'TopK', 'Transpose',
|
||||
'Unique',
|
||||
'Where']
|
||||
'Unique']
|
||||
|
||||
# ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
|
||||
default_processor_onnx_ops_requiring_ints_for_input_0 = ['Add',
|
||||
|
|
@ -392,6 +404,9 @@ def _create_operator_type_usage_processors():
|
|||
onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial']
|
||||
[add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops]
|
||||
|
||||
# Where always has a boolean first input so track the second input type for typed registration
|
||||
add(Input1TypedRegistrationProcessor('ai.onnx', 'Where'))
|
||||
|
||||
# we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
|
||||
# as that's what is used in the typed registration
|
||||
add(Output0TypedRegistrationProcessor('ai.onnx', 'QuantizeLinear'))
|
||||
|
|
|
|||
Loading…
Reference in a new issue