diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 7776e5deb1..a579bbfa17 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -4,6 +4,7 @@ #include "core/framework/data_types_internal.h" #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/cpu/tensor/utils.h" +#include "core/providers/op_kernel_type_control.h" #include #include "core/util/math.h" #include "core/mlas/inc/mlas.h" @@ -11,6 +12,44 @@ #include namespace onnxruntime { +// Supported types for operators that have type reduction enabled +namespace op_kernel_type_control { +// Max +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double); + +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0, + float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); + +// Min +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double); +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, + Input, 0, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); + +// Pow +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double); + +// Pow 12 and later has separate Base and Exponent types +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12, + Input, 0, int32_t, int64_t, float, double); +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12, + Input, 1, int32_t, int64_t, float, double); +} // namespace op_kernel_type_control + +// +// reduce the supported type lists to what's allowed in this build +// +using EnabledMax8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0); +using EnabledMax12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0); + +using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0); +using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0); + +using EnabledPow7Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0); +using EnabledPow12BaseTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + Pow, 12, Input, 0); +using EnabledPow12ExpTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + Pow, 12, Input, 1); + namespace functors { template <> void Exp::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const { @@ -56,25 +95,41 @@ void Exp::operator()(std::ptrdiff_t first, std::ptrdiff_t last) const { .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); -// var args are type constraints for T and T1 -#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, ...) \ - ONNX_CPU_OPERATOR_KERNEL( \ - OP_TYPE, \ - VERSION, \ - KernelDefBuilder() \ - .TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \ - .TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \ +#define REG_ELEMENTWISE_KERNEL_NONT(OP_TYPE, VERSION, KERNEL_CLASS, CONSTRAINTS) \ + ONNX_CPU_OPERATOR_KERNEL( \ + OP_TYPE, \ + VERSION, \ + KernelDefBuilder() \ + .TypeConstraint("T", CONSTRAINTS), \ KERNEL_CLASS); // var args are type constraints for T and T1 -#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, ...) \ - ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \ - OP_TYPE, \ - VERSION_FROM, \ - VERSION_TO, \ - KernelDefBuilder() \ - .TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()) \ - .TypeConstraint("T1", BuildKernelDefConstraints<__VA_ARGS__>()), \ +#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, CONSTRAINTS) \ + ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \ + OP_TYPE, \ + VERSION_FROM, \ + VERSION_TO, \ + KernelDefBuilder() \ + .TypeConstraint("T", CONSTRAINTS), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_KERNEL_NONT_2(OP_TYPE, VERSION, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \ + ONNX_CPU_OPERATOR_KERNEL( \ + OP_TYPE, \ + VERSION, \ + KernelDefBuilder() \ + .TypeConstraint("T", T1_CONSTRAINTS) \ + .TypeConstraint("T1", T2_CONSTRAINTS), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, T1_CONSTRAINTS, T2_CONSTRAINTS) \ + ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \ + OP_TYPE, \ + VERSION_FROM, \ + VERSION_TO, \ + KernelDefBuilder() \ + .TypeConstraint("T", T1_CONSTRAINTS) \ + .TypeConstraint("T1", T2_CONSTRAINTS), \ KERNEL_CLASS); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, float, Add); @@ -162,11 +217,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt); REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt); REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, float, double); -// To reduce templetization we choose to support the below types for both -// base and the exponent. This gives us 16 permutations -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 12, 12, Pow, int32_t, int64_t, float, double); -REG_ELEMENTWISE_KERNEL_NONT(Pow, 13, Pow, int32_t, int64_t, float, double); +const auto pow7_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +// To reduce templatization we choose to support the below types for both +// base and the exponent. This gives us 16 permutations. +const auto pow12_base_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +const auto pow12_exp_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, pow7_types); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow, pow12_base_types, pow12_exp_types); +REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow, pow12_base_types, pow12_exp_types); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp); @@ -187,16 +245,21 @@ REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, float, Sum_8); REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, double, Sum_8); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, float, double); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); -// Supposed to add BFloat16 but we are not supporting now, however, separate registration -REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); -REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, float); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +const std::vector max8_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +const std::vector max12_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, max8_types); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 12, 12, Max_8, max12_types); // Supposed to add BFloat16 but we are not supporting now, however, separate registration -REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); +REG_ELEMENTWISE_KERNEL_NONT(Max, 13, Max_8, max12_types); + +const std::vector min8_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +const std::vector min12_types = BuildKernelDefConstraintsFunctorFromTypeList{}(); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 8, 11, Min_8, min8_types); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Min, 12, 12, Min_8, min12_types); +// Supposed to add BFloat16 but we are not supporting now, however, separate registration +REG_ELEMENTWISE_KERNEL_NONT(Min, 13, Min_8, min12_types); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, float, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, double, Less); diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index ff9f1a34e2..191b777a40 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -30,7 +30,8 @@ using namespace boost::mp11; namespace onnxruntime { namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( +// we're using one set of types for all opsets of Cast +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, bool, float, double, @@ -39,7 +40,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( MLFloat16, BFloat16, std::string); -ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, bool, float, double, @@ -50,9 +51,10 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( } // namespace op_kernel_type_control namespace { - -using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0); -using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0); +using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Cast, Input, 0); +using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Cast, Output, 0); using IndirectCastTypes = TypeList; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 598cf312ab..b03298e356 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -14,14 +14,15 @@ namespace onnxruntime { // https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0, float, double); } // namespace op_kernel_type_control class IsInf final : public OpKernel { public: - using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0); + using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + IsInf, Input, 0); explicit IsInf(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/op_kernel_type_control.h b/onnxruntime/core/providers/op_kernel_type_control.h index beb0eb234d..72c78165d6 100644 --- a/onnxruntime/core/providers/op_kernel_type_control.h +++ b/onnxruntime/core/providers/op_kernel_type_control.h @@ -37,6 +37,9 @@ enum class OpArgDirection { using OpArgIndex = size_t; +// constant to use for type lists that are valid across all opsets +constexpr int kAllOpSets = -1; + namespace tags { // a tag that identifies the target (Op argument) of the specified types @@ -44,8 +47,9 @@ template struct OpArg {}; // a tag that indicates the supported types for a particular Op argument, identified by OpArgTag, -// for a kernel in a particular provider, identified by ProviderTag -template +// for a kernel in a particular provider, identified by ProviderTag. as the types may change between opsets, +// the opset must also be specified. if the type list is not opset specific, use kAllOpSets as the value. +template struct Supported {}; // a tag that indicates the allowed types for a particular Op argument, identified by OpArgTag @@ -151,39 +155,64 @@ struct EnabledTypes { * @param OpProvider The Op provider. * @param OpDomain The Op domain. * @param OpName The Op name. + * @param OpSet The opset that this set of supported types applies to. * @param ArgDirection Direction of the given Op kernel argument - Input or Output. * @param ArgIndex Index of the given Op kernel argument. * @param ... The types. */ #define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( \ - OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \ + OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, ...) \ class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \ class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider); \ template <> \ struct TypesHolder< \ ::onnxruntime::op_kernel_type_control::tags::Supported< \ ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \ - ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>> { \ + ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \ + OpSet>> { \ using types = ::onnxruntime::TypeList<__VA_ARGS__>; \ }; /** - * TypeList type with the enabled types for a given Op kernel argument. + * Specifies a supported set of types for a given Op kernel argument that is valid for all opsets. + * This should be specified with the Op kernel implementation. + * + * Note: This should be called from the onnxruntime::op_kernel_type_control namespace. * * @param OpProvider The Op provider. * @param OpDomain The Op domain. * @param OpName The Op name. * @param ArgDirection Direction of the given Op kernel argument - Input or Output. * @param ArgIndex Index of the given Op kernel argument. + * @param ... The types. + */ +#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( \ + OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \ + ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(OpProvider, OpDomain, OpName, \ + ::onnxruntime::op_kernel_type_control::kAllOpSets, \ + ArgDirection, ArgIndex, __VA_ARGS__) + +/** + * TypeList type with the enabled types for a given Op kernel argument. + * This is created by intersecting the supported types with any type restrictions coming from the allowed or global + * type lists. + * + * @param OpProvider The Op provider. + * @param OpDomain The Op domain. + * @param OpName The Op name. + * @param OpSet The opset to use for the supported types list. + * @param ArgDirection Direction of the given Op kernel argument - Input or Output. + * @param ArgIndex Index of the given Op kernel argument. */ #define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \ - OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \ - ::onnxruntime::op_kernel_type_control::EnabledTypes< \ + OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \ + ::onnxruntime::op_kernel_type_control::EnabledTypes< \ ::onnxruntime::op_kernel_type_control::TypesHolder< \ ::onnxruntime::op_kernel_type_control::tags::Supported< \ ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \ ::onnxruntime::op_kernel_type_control:: \ - ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>>, \ + ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider), \ + OpSet>>, \ ::onnxruntime::TypeList< \ ::onnxruntime::op_kernel_type_control::TypesHolder< \ ::onnxruntime::op_kernel_type_control::tags::Allowed< \ @@ -192,7 +221,9 @@ struct EnabledTypes { ::onnxruntime::op_kernel_type_control::tags::GlobalAllowed>>>::types /** - * std::tuple type with the enabled types for a given Op kernel argument. + * TypeList type with the enabled types for a given Op kernel argument that is valid for all opsets. + * This is created by intersecting the supported types with any type restrictions coming from the allowed or global + * type lists. * * @param OpProvider The Op provider. * @param OpDomain The Op domain. @@ -200,13 +231,43 @@ struct EnabledTypes { * @param ArgDirection Direction of the given Op kernel argument - Input or Output. * @param ArgIndex Index of the given Op kernel argument. */ -#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \ - OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \ - ::boost::mp11::mp_rename< \ - ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \ - OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, SupportedTypeList), \ +#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( \ + OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \ + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(OpProvider, OpDomain, OpName, \ + ::onnxruntime::op_kernel_type_control::kAllOpSets, \ + ArgDirection, ArgIndex) + +/** + * std::tuple type with the enabled types for a given Op kernel argument. + * + * @param OpProvider The Op provider. + * @param OpDomain The Op domain. + * @param OpName The Op name. + * @param OpSet The opset to use for the supported types list. + * @param ArgDirection Direction of the given Op kernel argument - Input or Output. + * @param ArgIndex Index of the given Op kernel argument. + */ +#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \ + OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \ + ::boost::mp11::mp_rename< \ + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \ + OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \ std::tuple> +/** + * std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets. + * + * @param OpProvider The Op provider. + * @param OpDomain The Op domain. + * @param OpName The Op name. + * @param ArgDirection Direction of the given Op kernel argument - Input or Output. + * @param ArgIndex Index of the given Op kernel argument. + */ +#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \ + OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \ + ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \ + ::onnxruntime::op_kernel_type_control::kAllOpSets, \ + ArgDirection, ArgIndex) /** * Usage example: * diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index 47a2dd460e..035f553744 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -241,16 +241,17 @@ def _create_operator_type_usage_processors(): # - Mobilenet + SSD Mobilenet + MobileBert # - some known large kernels # - # Ops we are ignoring currently so as not to produce meaningless output: - # Implementation is not type specific: + # Ops we are ignoring currently so as not to produce meaningless/unused output: + # - Implementation is not type specific: # If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze - # Only one type supported in the ORT implementation: - # FusedConv, FusedGemm, FusedMatMul, TransposeMatMul - # - default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Concat', 'Conv', - 'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand', 'Flatten', + # - Only one type supported in the ORT implementation: + # FusedConv, FusedGemm, FusedMatMul, TransposeMatMul + # - Implementation does not have any significant type specific code: + # Concat, Flatten, Not, QLinearConv, Reshape, Shape, Squeeze, Unsqueeze + default_processor_onnx_ops = ['Add', 'AveragePool', 'BatchNormalization', 'Clip', 'Conv', + 'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand', 'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul', - 'NonMaxSuppression', 'NonZero', 'Pad', 'QLinearConv', 'Range', 'Relu', 'Resize', + 'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize', 'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose'] internal_ops = ['QLinearAdd', 'QLinearMul'] @@ -271,7 +272,7 @@ def _create_operator_type_usage_processors(): [add(DefaultTypeUsageProcessor('com.microsoft', op)) for op in internal_ops] # - # Operators that require slightly different handling + # Operators that require custom handling # add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0 @@ -282,12 +283,16 @@ def _create_operator_type_usage_processors(): # Pow dispatches on base and exponential types add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1])) + # ConstantOfShape switches on size of output type + add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0])) + # Random generator ops produce new data so we track the output type onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial'] [add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops] - # we only support 'float' as input for QuantizeLinear so just track the output type + # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type add(DefaultTypeUsageProcessor('ai.onnx', 'QuantizeLinear', inputs=[], outputs=[0])) + add(DefaultTypeUsageProcessor('ai.onnx', 'DynamicQuantizeLinear', inputs=[], outputs=[0])) # OneHot concatenates type strings into a triple in the typed registration # e.g. float_int64_t_int64_t