diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 4247dc7845..08b6a31938 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -59,21 +59,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Neg); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Floor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Ceil); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Reciprocal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Log); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Sum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Min); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not); @@ -86,8 +86,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos); @@ -216,6 +216,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Com class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); @@ -223,7 +224,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_string_int64_t, OneHot); @@ -334,21 +335,21 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -361,8 +362,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -491,6 +492,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -498,7 +500,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index e76ce2f8cd..ece68834e7 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -10,267 +10,100 @@ namespace onnxruntime { -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Add, - 7, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Add); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Add, - 7, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Add); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Add, - 7, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Add); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Add, - 7, - int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Add); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sub, - 7, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sub); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sub, - 7, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sub); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sub, - 7, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sub); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sub, - 7, - int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sub); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Mul, - 7, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mul); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Mul, - 7, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mul); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Mul, - 7, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mul); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Mul, - 7, - int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mul); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Div, - 7, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Div); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Div, - 7, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Div); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Div, - 7, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Div); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Div, - 7, - int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Div); - -#define REG_ABS_KERNEL(TYPE) \ +#define REG_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - Abs, \ - 6, \ + OP_TYPE, \ + VERSION, \ TYPE, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Abs); + KERNEL_CLASS); -REG_ABS_KERNEL(float) -REG_ABS_KERNEL(double) -REG_ABS_KERNEL(int8_t) -REG_ABS_KERNEL(int16_t) -REG_ABS_KERNEL(int32_t) -REG_ABS_KERNEL(int64_t) -REG_ABS_KERNEL(uint8_t) -REG_ABS_KERNEL(uint16_t) -REG_ABS_KERNEL(uint32_t) -REG_ABS_KERNEL(uint64_t) +#define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + OP_TYPE, \ + VERSION_FROM, VERSION_TO, \ + TYPE, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Neg, - 6, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, double, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int32_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, int64_t, Add); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Neg, - 6, - int8_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, float, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, double, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int32_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 7, int64_t, Sub); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Neg, - 6, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, float, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, double, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int32_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 7, int64_t, Mul); -ONNX_CPU_OPERATOR_KERNEL( - Floor, - 6, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Floor); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, float, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, double, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int32_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 7, int64_t, Div); -ONNX_CPU_OPERATOR_KERNEL( - Ceil, - 6, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Ceil); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, float, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, double, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int8_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int16_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int32_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, int64_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint8_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint16_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint32_t, Abs); +REG_ELEMENTWISE_TYPED_KERNEL(Abs, 6, uint64_t, Abs); -ONNX_CPU_OPERATOR_KERNEL( - Reciprocal, - 6, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Reciprocal); +REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, float, Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int8_t, Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Neg, 6, int32_t, Neg); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sqrt, - 6, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sqrt); +REG_ELEMENTWISE_TYPED_KERNEL(Floor, 6, float, Floor); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Sqrt, - 6, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sqrt); +REG_ELEMENTWISE_TYPED_KERNEL(Ceil, 6, float, Ceil); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Pow, - 7, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Pow); +REG_ELEMENTWISE_TYPED_KERNEL(Reciprocal, 6, float, Reciprocal); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Pow, - 7, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Pow); +REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, float, Sqrt); +REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 6, double, Sqrt); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Exp, - 6, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Exp); +REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, float, Pow); +REG_ELEMENTWISE_TYPED_KERNEL(Pow, 7, double, Pow); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Exp, - 6, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Exp); +REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, float, Exp); +REG_ELEMENTWISE_TYPED_KERNEL(Exp, 6, double, Exp); -ONNX_CPU_OPERATOR_KERNEL( - Log, - 6, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Log); +REG_ELEMENTWISE_TYPED_KERNEL(Log, 6, float, Log); -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Sum, - 6, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sum_6); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, float, Sum_6); +REG_ELEMENTWISE_TYPED_KERNEL(Sum, 8, float, Sum_8); -ONNX_CPU_OPERATOR_KERNEL( - Sum, - 8, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Sum_8); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Min, 6, 7, float, Min_6); +REG_ELEMENTWISE_TYPED_KERNEL(Min, 8, float, Min_8); -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Min, - 6, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Min_6); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6); +REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, float, Max_8); +REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8); -ONNX_CPU_OPERATOR_KERNEL( - Min, - 8, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Min_8); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less); +REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less); -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Max, - 6, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Max_6); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) +REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater); +REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int64_t, Greater); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Max, - 8, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Max_8); +REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, bool, Equal); +REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int32_t, Equal); +REG_ELEMENTWISE_TYPED_KERNEL(Equal, 7, int64_t, Equal); +REG_ELEMENTWISE_TYPED_KERNEL(Equal, 11, float, Equal); -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Max, - 8, - double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Max_8); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6); +REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8); + +REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf); ONNX_CPU_OPERATOR_KERNEL( Not, @@ -296,80 +129,6 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Xor); -ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( - Less, - 7, 9, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Less); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Less, - 9, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Less); - -ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( - Greater, - 7, 9, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Greater); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Greater, - 9, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Greater); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Equal, - 7, - bool, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Equal); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Equal, - 7, - int32_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Equal); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Equal, - 7, - int64_t, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Equal); - -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Equal, - 11, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Equal); - -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Mean, - 6, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mean_6); - -ONNX_CPU_OPERATOR_KERNEL( - Mean, - 8, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Mean_8); - -ONNX_CPU_OPERATOR_KERNEL( - Erf, - 9, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Erf); - template Status Add::Compute(OpKernelContext* context) const { return BroadcastTwo( diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 81762ff5a1..bfff81c5be 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -806,7 +806,7 @@ TEST(MathOpTest, Less_Scalar1) { test.Run(); } -TEST(MathOpTest, Greater) { +TEST(MathOpTest, Greater_7) { OpTester test("Greater"); std::vector dims{4}; test.AddInput("A", dims, {1.0f, 0.0f, -1.0f, -1.0f}); @@ -815,6 +815,36 @@ TEST(MathOpTest, Greater) { test.Run(); } +TEST( MathOpTest, Greater_9_float ) +{ + OpTester test( "Greater", 9 ); + std::vector dims { 4 }; + test.AddInput( "A", dims, { 1.0f, 0.0f, -1.0f, -1.0f } ); + test.AddInput( "B", dims, { 1.0f, 1.0f, 2.0f, -1.0f } ); + test.AddOutput( "C", dims, { false, false, false, false } ); + test.Run(); +} + +TEST( MathOpTest, Greater_9_int32 ) +{ + OpTester test( "Greater", 9 ); + std::vector dims { 4 }; + test.AddInput( "A", dims, { 10, 11, 12, 13 } ); + test.AddInput( "B", dims, { 15, 7, 12, 9 } ); + test.AddOutput( "C", dims, { false, true, false, true } ); + test.Run(); +} + +TEST( MathOpTest, Greater_9_int64 ) +{ + OpTester test( "Greater", 9 ); + std::vector dims { 4 }; + test.AddInput( "A", dims, { 10, 11, 12, 13 } ); + test.AddInput( "B", dims, { 15, 7, 12, 9 } ); + test.AddOutput( "C", dims, { false, true, false, true } ); + test.Run(); +} + TEST(MathOpTest, Equal_bool) { OpTester test("Equal"); std::vector dims{4};