Add type reduction support to Min, Max and Pow (#6519)

* Add type reduction support to Min, Max and Pow
Update the C++ type reduction infrastructure to allow specifying an opset for the supported types list, as those can change across opset versions.
Minor updates to the type usage tracking script
* Add 'all opsets' macros and constant
This commit is contained in:
Scott McKay 2021-02-03 06:51:35 +10:00 committed by GitHub
parent fbb24b57d0
commit 8d53ef69e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 60 deletions

View file

@ -4,6 +4,7 @@
#include "core/framework/data_types_internal.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/op_kernel_type_control.h"
#include <unsupported/Eigen/SpecialFunctions>
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"
@ -11,6 +12,44 @@
#include <cmath>
namespace onnxruntime {
// Supported types for operators that have type reduction enabled
namespace op_kernel_type_control {
// Max
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
// Min
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12,
Input, 0, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
// Pow
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);
// Pow 12 and later has separate Base and Exponent types
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 0, int32_t, int64_t, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 1, int32_t, int64_t, float, double);
} // namespace op_kernel_type_control
//
// reduce the supported type lists to what's allowed in this build
//
using EnabledMax8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0);
using EnabledMax12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0);
using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
using EnabledPow7Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
using EnabledPow12BaseTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
Pow, 12, Input, 0);
using EnabledPow12ExpTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
Pow, 12, Input, 1);
namespace functors {
template <>
void Exp<float>::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const {
@ -56,25 +95,41 @@ void Exp<float>::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const {
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
KERNEL_CLASS<TYPE>);
// var args are type constraints for T and T1
#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, ...) \
ONNX_CPU_OPERATOR_KERNEL( \
OP_TYPE, \
VERSION, \
KernelDefBuilder() \
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \
.TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \
#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, CONSTRAINTS) \
ONNX_CPU_OPERATOR_KERNEL( \
OP_TYPE, \
VERSION, \
KernelDefBuilder() \
.TypeConstraint("T", CONSTRAINTS), \
KERNEL_CLASS);
// var args are type constraints for T and T1
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, ...) \
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
OP_TYPE, \
VERSION_FROM, \
VERSION_TO, \
KernelDefBuilder() \
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \
.TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, CONSTRAINTS) \
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
OP_TYPE, \
VERSION_FROM, \
VERSION_TO, \
KernelDefBuilder() \
.TypeConstraint("T", CONSTRAINTS), \
KERNEL_CLASS);
#define REG_ELEMENTWISE_KERNEL_NONT_2(OP_TYPE, VERSION, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \
ONNX_CPU_OPERATOR_KERNEL( \
OP_TYPE, \
VERSION, \
KernelDefBuilder() \
.TypeConstraint("T", T1_CONSTRAINTS) \
.TypeConstraint("T1", T2_CONSTRAINTS), \
KERNEL_CLASS);
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
OP_TYPE, \
VERSION_FROM, \
VERSION_TO, \
KernelDefBuilder() \
.TypeConstraint("T", T1_CONSTRAINTS) \
.TypeConstraint("T1", T2_CONSTRAINTS), \
KERNEL_CLASS);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, float, Add);
@ -162,11 +217,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, float, double);
// To reduce templetization we choose to support the below types for both
// base and the exponent. This gives us 16 permutations
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 12, 12, Pow, int32_t, int64_t, float, double);
REG_ELEMENTWISE_KERNEL_NONT(Pow, 13, Pow, int32_t, int64_t, float, double);
const auto pow7_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow7Types>{}();
// To reduce templatization we choose to support the below types for both
// base and the exponent. This gives us 16 permutations.
const auto pow12_base_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow12BaseTypes>{}();
const auto pow12_exp_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow12ExpTypes>{}();
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, pow7_types);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow, pow12_base_types, pow12_exp_types);
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow, pow12_base_types, pow12_exp_types);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp);
@ -187,16 +245,21 @@ REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, float, Sum_8);
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, double, Sum_8);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, float, double);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, float);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
const std::vector<MLDataType> max8_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMax8Types>{}();
const std::vector<MLDataType> max12_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMax12Types>{}();
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, max8_types);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, max12_types);
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, max12_types);
const std::vector<MLDataType> min8_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMin8Types>{}();
const std::vector<MLDataType> min12_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMin12Types>{}();
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, min8_types);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, min12_types);
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, min12_types);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, float, Less);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, double, Less);

View file

@ -30,7 +30,8 @@ using namespace boost::mp11;
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
// we're using one set of types for all opsets of Cast
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
bool,
float, double,
@ -39,7 +40,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
MLFloat16, BFloat16,
std::string);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
bool,
float, double,
@ -50,9 +51,10 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
} // namespace op_kernel_type_control
namespace {
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0);
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0);
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Input, 0);
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Output, 0);
using IndirectCastTypes = TypeList<MLFloat16, BFloat16>;

View file

@ -14,14 +14,15 @@ namespace onnxruntime {
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
float, double);
} // namespace op_kernel_type_control
class IsInf final : public OpKernel {
public:
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0);
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);
explicit IsInf(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

View file

@ -37,6 +37,9 @@ enum class OpArgDirection {
using OpArgIndex = size_t;
// constant to use for type lists that are valid across all opsets
constexpr int kAllOpSets = -1;
namespace tags {
// a tag that identifies the target (Op argument) of the specified types
@ -44,8 +47,9 @@ template <typename OpTag, OpArgDirection ArgDirection, OpArgIndex ArgIndex>
struct OpArg {};
// a tag that indicates the supported types for a particular Op argument, identified by OpArgTag,
// for a kernel in a particular provider, identified by ProviderTag
template <typename OpArgTag, typename ProviderTag>
// for a kernel in a particular provider, identified by ProviderTag. as the types may change between opsets,
// the opset must also be specified. if the type list is not opset specific, use kAllOpSets as the value.
template <typename OpArgTag, typename ProviderTag, int OpSet>
struct Supported {};
// a tag that indicates the allowed types for a particular Op argument, identified by OpArgTag
@ -151,39 +155,64 @@ struct EnabledTypes {
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param OpSet The opset that this set of supported types applies to.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
* @param ... The types.
*/
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, ...) \
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider); \
template <> \
struct TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Supported< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>> { \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \
OpSet>> { \
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
};
/**
* TypeList type with the enabled types for a given Op kernel argument.
* Specifies a supported set of types for a given Op kernel argument that is valid for all opsets.
* This should be specified with the Op kernel implementation.
*
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
* @param ... The types.
*/
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(OpProvider, OpDomain, OpName, \
::onnxruntime::op_kernel_type_control::kAllOpSets, \
ArgDirection, ArgIndex, __VA_ARGS__)
/**
* TypeList type with the enabled types for a given Op kernel argument.
* This is created by intersecting the supported types with any type restrictions coming from the allowed or global
* type lists.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param OpSet The opset to use for the supported types list.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
::onnxruntime::op_kernel_type_control::EnabledTypes< \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
::onnxruntime::op_kernel_type_control::EnabledTypes< \
::onnxruntime::op_kernel_type_control::TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Supported< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
::onnxruntime::op_kernel_type_control:: \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>>, \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \
OpSet>>, \
::onnxruntime::TypeList< \
::onnxruntime::op_kernel_type_control::TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Allowed< \
@ -192,7 +221,9 @@ struct EnabledTypes {
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed>>>::types
/**
* std::tuple type with the enabled types for a given Op kernel argument.
* TypeList type with the enabled types for a given Op kernel argument that is valid for all opsets.
* This is created by intersecting the supported types with any type restrictions coming from the allowed or global
* type lists.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
@ -200,13 +231,43 @@ struct EnabledTypes {
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
::boost::mp11::mp_rename< \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, SupportedTypeList), \
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(OpProvider, OpDomain, OpName, \
::onnxruntime::op_kernel_type_control::kAllOpSets, \
ArgDirection, ArgIndex)
/**
* std::tuple type with the enabled types for a given Op kernel argument.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param OpSet The opset to use for the supported types list.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
::boost::mp11::mp_rename< \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \
std::tuple>
/**
* std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \
::onnxruntime::op_kernel_type_control::kAllOpSets, \
ArgDirection, ArgIndex)
/**
* Usage example:
*

View file

@ -241,16 +241,17 @@ def _create_operator_type_usage_processors():
# - Mobilenet + SSD Mobilenet + MobileBert
# - some known large kernels
#
# Ops we are ignoring currently so as not to produce meaningless output:
# Implementation is not type specific:
# Ops we are ignoring currently so as not to produce meaningless/unused output:
# - Implementation is not type specific:
# If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
# Only one type supported in the ORT implementation:
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
#
default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Concat', 'Conv',
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand', 'Flatten',
# - Only one type supported in the ORT implementation:
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
# - Implementation does not have any significant type specific code:
# Concat, Flatten, Not, QLinearConv, Reshape, Shape, Squeeze, Unsqueeze
default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Conv',
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand',
'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul',
'NonMaxSuppression', 'NonZero', 'Pad', 'QLinearConv', 'Range', 'Relu', 'Resize',
'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize',
'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
internal_ops = ['QLinearAdd', 'QLinearMul']
@ -271,7 +272,7 @@ def _create_operator_type_usage_processors():
[add(DefaultTypeUsageProcessor('com.microsoft', op)) for op in internal_ops]
#
# Operators that require slightly different handling
# Operators that require custom handling
#
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0
@ -282,12 +283,16 @@ def _create_operator_type_usage_processors():
# Pow dispatches on base and exponential types
add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1]))
# ConstantOfShape switches on size of output type
add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0]))
# Random generator ops produce new data so we track the output type
onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial']
[add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops]
# we only support 'float' as input for QuantizeLinear so just track the output type
# we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
add(DefaultTypeUsageProcessor('ai.onnx', 'QuantizeLinear', inputs=[], outputs=[0]))
add(DefaultTypeUsageProcessor('ai.onnx', 'DynamicQuantizeLinear', inputs=[], outputs=[0]))
# OneHot concatenates type strings into a triple in the typed registration
# e.g. float_int64_t_int64_t