mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Add type reduction support to Min, Max and Pow (#6519)
* Add type reduction support to Min, Max and Pow Update the C++ type reduction infrastructure to allow specifying an opset for the supported types list, as those can change across opset versions. Minor updates to the type usage tracking script * Add 'all opsets' macros and constant
This commit is contained in:
parent
fbb24b57d0
commit
8d53ef69e5
5 changed files with 192 additions and 60 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include "core/util/math.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
|
@ -11,6 +12,44 @@
|
|||
#include <cmath>
|
||||
|
||||
namespace onnxruntime {
|
||||
// Supported types for operators that have type reduction enabled
|
||||
namespace op_kernel_type_control {
|
||||
// Max
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
|
||||
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
|
||||
// Min
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12,
|
||||
Input, 0, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
|
||||
// Pow
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);
|
||||
|
||||
// Pow 12 and later has separate Base and Exponent types
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
|
||||
Input, 0, int32_t, int64_t, float, double);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
|
||||
Input, 1, int32_t, int64_t, float, double);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
//
|
||||
// reduce the supported type lists to what's allowed in this build
|
||||
//
|
||||
using EnabledMax8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0);
|
||||
using EnabledMax12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0);
|
||||
|
||||
using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
|
||||
using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
|
||||
|
||||
using EnabledPow7Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
|
||||
using EnabledPow12BaseTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
|
||||
Pow, 12, Input, 0);
|
||||
using EnabledPow12ExpTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
|
||||
Pow, 12, Input, 1);
|
||||
|
||||
namespace functors {
|
||||
template <>
|
||||
void Exp<float>::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const {
|
||||
|
|
@ -56,25 +95,41 @@ void Exp<float>::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const {
|
|||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
// var args are type constraints for T and T1
|
||||
#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, ...) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \
|
||||
#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, CONSTRAINTS) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", CONSTRAINTS), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
// var args are type constraints for T and T1
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, ...) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, \
|
||||
VERSION_TO, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, CONSTRAINTS) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, \
|
||||
VERSION_TO, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", CONSTRAINTS), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define REG_ELEMENTWISE_KERNEL_NONT_2(OP_TYPE, VERSION, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \
|
||||
ONNX_CPU_OPERATOR_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", T1_CONSTRAINTS) \
|
||||
.TypeConstraint("T1", T2_CONSTRAINTS), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, \
|
||||
VERSION_TO, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", T1_CONSTRAINTS) \
|
||||
.TypeConstraint("T1", T2_CONSTRAINTS), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, float, Add);
|
||||
|
|
@ -162,11 +217,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, float, double);
|
||||
// To reduce templetization we choose to support the below types for both
|
||||
// base and the exponent. This gives us 16 permutations
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 12, 12, Pow, int32_t, int64_t, float, double);
|
||||
REG_ELEMENTWISE_KERNEL_NONT(Pow, 13, Pow, int32_t, int64_t, float, double);
|
||||
const auto pow7_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow7Types>{}();
|
||||
// To reduce templatization we choose to support the below types for both
|
||||
// base and the exponent. This gives us 16 permutations.
|
||||
const auto pow12_base_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow12BaseTypes>{}();
|
||||
const auto pow12_exp_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledPow12ExpTypes>{}();
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, pow7_types);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow, pow12_base_types, pow12_exp_types);
|
||||
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow, pow12_base_types, pow12_exp_types);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp);
|
||||
|
|
@ -187,16 +245,21 @@ REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, float, Sum_8);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, double, Sum_8);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, float, double);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
|
||||
REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, float);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
const std::vector<MLDataType> max8_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMax8Types>{}();
|
||||
const std::vector<MLDataType> max12_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMax12Types>{}();
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, max8_types);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, max12_types);
|
||||
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
|
||||
REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
|
||||
REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, max12_types);
|
||||
|
||||
const std::vector<MLDataType> min8_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMin8Types>{}();
|
||||
const std::vector<MLDataType> min12_types = BuildKernelDefConstraintsFunctorFromTypeList<EnabledMin12Types>{}();
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, min8_types);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, min12_types);
|
||||
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
|
||||
REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, min12_types);
|
||||
|
||||
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, float, Less);
|
||||
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, double, Less);
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ using namespace boost::mp11;
|
|||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
// we're using one set of types for all opsets of Cast
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
|
||||
bool,
|
||||
float, double,
|
||||
|
|
@ -39,7 +40,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
|||
MLFloat16, BFloat16,
|
||||
std::string);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
|
||||
bool,
|
||||
float, double,
|
||||
|
|
@ -50,9 +51,10 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
|||
} // namespace op_kernel_type_control
|
||||
|
||||
namespace {
|
||||
|
||||
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0);
|
||||
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0);
|
||||
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
Cast, Input, 0);
|
||||
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
Cast, Output, 0);
|
||||
|
||||
using IndirectCastTypes = TypeList<MLFloat16, BFloat16>;
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,15 @@ namespace onnxruntime {
|
|||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
|
||||
float, double);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
class IsInf final : public OpKernel {
|
||||
public:
|
||||
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0);
|
||||
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
IsInf, Input, 0);
|
||||
|
||||
explicit IsInf(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ enum class OpArgDirection {
|
|||
|
||||
using OpArgIndex = size_t;
|
||||
|
||||
// constant to use for type lists that are valid across all opsets
|
||||
constexpr int kAllOpSets = -1;
|
||||
|
||||
namespace tags {
|
||||
|
||||
// a tag that identifies the target (Op argument) of the specified types
|
||||
|
|
@ -44,8 +47,9 @@ template <typename OpTag, OpArgDirection ArgDirection, OpArgIndex ArgIndex>
|
|||
struct OpArg {};
|
||||
|
||||
// a tag that indicates the supported types for a particular Op argument, identified by OpArgTag,
|
||||
// for a kernel in a particular provider, identified by ProviderTag
|
||||
template <typename OpArgTag, typename ProviderTag>
|
||||
// for a kernel in a particular provider, identified by ProviderTag. as the types may change between opsets,
|
||||
// the opset must also be specified. if the type list is not opset specific, use kAllOpSets as the value.
|
||||
template <typename OpArgTag, typename ProviderTag, int OpSet>
|
||||
struct Supported {};
|
||||
|
||||
// a tag that indicates the allowed types for a particular Op argument, identified by OpArgTag
|
||||
|
|
@ -151,39 +155,64 @@ struct EnabledTypes {
|
|||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param OpSet The opset that this set of supported types applies to.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
* @param ... The types.
|
||||
*/
|
||||
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, ...) \
|
||||
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
|
||||
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider); \
|
||||
template <> \
|
||||
struct TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Supported< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>> { \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \
|
||||
OpSet>> { \
|
||||
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
|
||||
};
|
||||
|
||||
/**
|
||||
* TypeList type with the enabled types for a given Op kernel argument.
|
||||
* Specifies a supported set of types for a given Op kernel argument that is valid for all opsets.
|
||||
* This should be specified with the Op kernel implementation.
|
||||
*
|
||||
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
* @param ... The types.
|
||||
*/
|
||||
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(OpProvider, OpDomain, OpName, \
|
||||
::onnxruntime::op_kernel_type_control::kAllOpSets, \
|
||||
ArgDirection, ArgIndex, __VA_ARGS__)
|
||||
|
||||
/**
|
||||
* TypeList type with the enabled types for a given Op kernel argument.
|
||||
* This is created by intersecting the supported types with any type restrictions coming from the allowed or global
|
||||
* type lists.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param OpSet The opset to use for the supported types list.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
::onnxruntime::op_kernel_type_control::EnabledTypes< \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
|
||||
::onnxruntime::op_kernel_type_control::EnabledTypes< \
|
||||
::onnxruntime::op_kernel_type_control::TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Supported< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
|
||||
::onnxruntime::op_kernel_type_control:: \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>>, \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \
|
||||
OpSet>>, \
|
||||
::onnxruntime::TypeList< \
|
||||
::onnxruntime::op_kernel_type_control::TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Allowed< \
|
||||
|
|
@ -192,7 +221,9 @@ struct EnabledTypes {
|
|||
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed>>>::types
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument.
|
||||
* TypeList type with the enabled types for a given Op kernel argument that is valid for all opsets.
|
||||
* This is created by intersecting the supported types with any type restrictions coming from the allowed or global
|
||||
* type lists.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
|
|
@ -200,13 +231,43 @@ struct EnabledTypes {
|
|||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
::boost::mp11::mp_rename< \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, SupportedTypeList), \
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(OpProvider, OpDomain, OpName, \
|
||||
::onnxruntime::op_kernel_type_control::kAllOpSets, \
|
||||
ArgDirection, ArgIndex)
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param OpSet The opset to use for the supported types list.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
|
||||
::boost::mp11::mp_rename< \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \
|
||||
std::tuple>
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \
|
||||
::onnxruntime::op_kernel_type_control::kAllOpSets, \
|
||||
ArgDirection, ArgIndex)
|
||||
/**
|
||||
* Usage example:
|
||||
*
|
||||
|
|
|
|||
|
|
@ -241,16 +241,17 @@ def _create_operator_type_usage_processors():
|
|||
# - Mobilenet + SSD Mobilenet + MobileBert
|
||||
# - some known large kernels
|
||||
#
|
||||
# Ops we are ignoring currently so as not to produce meaningless output:
|
||||
# Implementation is not type specific:
|
||||
# Ops we are ignoring currently so as not to produce meaningless/unused output:
|
||||
# - Implementation is not type specific:
|
||||
# If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
|
||||
# Only one type supported in the ORT implementation:
|
||||
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
|
||||
#
|
||||
default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Concat', 'Conv',
|
||||
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand', 'Flatten',
|
||||
# - Only one type supported in the ORT implementation:
|
||||
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
|
||||
# - Implementation does not have any significant type specific code:
|
||||
# Concat, Flatten, Not, QLinearConv, Reshape, Shape, Squeeze, Unsqueeze
|
||||
default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Conv',
|
||||
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand',
|
||||
'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul',
|
||||
'NonMaxSuppression', 'NonZero', 'Pad', 'QLinearConv', 'Range', 'Relu', 'Resize',
|
||||
'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize',
|
||||
'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
|
||||
|
||||
internal_ops = ['QLinearAdd', 'QLinearMul']
|
||||
|
|
@ -271,7 +272,7 @@ def _create_operator_type_usage_processors():
|
|||
[add(DefaultTypeUsageProcessor('com.microsoft', op)) for op in internal_ops]
|
||||
|
||||
#
|
||||
# Operators that require slightly different handling
|
||||
# Operators that require custom handling
|
||||
#
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0
|
||||
|
||||
|
|
@ -282,12 +283,16 @@ def _create_operator_type_usage_processors():
|
|||
# Pow dispatches on base and exponential types
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1]))
|
||||
|
||||
# ConstantOfShape switches on size of output type
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0]))
|
||||
|
||||
# Random generator ops produce new data so we track the output type
|
||||
onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial']
|
||||
[add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops]
|
||||
|
||||
# we only support 'float' as input for QuantizeLinear so just track the output type
|
||||
# we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'QuantizeLinear', inputs=[], outputs=[0]))
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'DynamicQuantizeLinear', inputs=[], outputs=[0]))
|
||||
|
||||
# OneHot concatenates type strings into a triple in the typed registration
|
||||
# e.g. float_int64_t_int64_t
|
||||
|
|
|
|||
Loading…
Reference in a new issue