diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 6938ecc37b..4247dc7845 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -31,6 +31,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Ran class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniformLike); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Multinomial); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sub); @@ -61,16 +62,20 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Floor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Ceil); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Reciprocal); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sqrt); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Pow); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Exp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Pow); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Pow); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Exp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Exp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Log); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Min); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Min); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Max); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Max); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Max); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Not); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, And); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Or); @@ -83,7 +88,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, Mean); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Mean); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Sin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, double, Sin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Cos); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Tan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Asin); @@ -133,8 +139,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, ArgMax); @@ -298,6 +306,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -328,16 +337,20 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -350,7 +363,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -401,8 +415,10 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 32b877b82d..e76ce2f8cd 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -17,6 +17,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Add); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Add, + 7, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Add); + ONNX_CPU_OPERATOR_TYPED_KERNEL( Add, 7, @@ -173,24 +180,48 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Reciprocal); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( Sqrt, 6, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Sqrt); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Sqrt, + 6, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Sqrt); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( Pow, 7, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Pow); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Pow, + 7, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Pow); + +ONNX_CPU_OPERATOR_TYPED_KERNEL( Exp, 6, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Exp); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Exp, + 6, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Exp); + ONNX_CPU_OPERATOR_KERNEL( Log, 6, @@ -227,12 +258,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Max_6); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( Max, 8, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Max_8); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Max, + 8, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Max_8); + ONNX_CPU_OPERATOR_KERNEL( Not, 1, @@ -397,43 +436,43 @@ Status Reciprocal::Compute(OpKernelContext* ctx) const { return Status::OK(); } -template <> -Status Sqrt::Compute(OpKernelContext* ctx) const { +template +Status Sqrt::Compute(OpKernelContext* ctx) const { auto& X = *ctx->Input(0); auto& Y = *ctx->Output(0, X.Shape()); - EigenMap(Y) = EigenMap(X).cwiseSqrt(); + EigenMap(Y) = EigenMap(X).cwiseSqrt(); return Status::OK(); } -template <> -Status Pow::Compute(OpKernelContext* context) const { +template +Status Pow::Compute(OpKernelContext* context) const { const Tensor& Y = *context->Input(1); - std::function, ConstEigenVectorMap, float)> input1scalar = - [](EigenVectorMap output, ConstEigenVectorMap input0, float input1) { output = Eigen::pow(input0.array(), input1); }; + std::function, ConstEigenVectorMap, T)> input1scalar = + [](EigenVectorMap output, ConstEigenVectorMap input0, T input1) { output = Eigen::pow(input0.array(), input1); }; if (Y.Shape().Size() == 1) { - float value = *Y.Data(); + T value = *Y.Data(); if (value == 2.0) { - input1scalar = [](EigenVectorMap output, ConstEigenVectorMap input0, float) { output = Eigen::square(input0.array()); }; + input1scalar = [](EigenVectorMap output, ConstEigenVectorMap input0, T) { output = Eigen::square(input0.array()); }; } else if (value == 3.0) { - input1scalar = [](EigenVectorMap output, ConstEigenVectorMap input0, float) { output = Eigen::cube(input0.array()); }; + input1scalar = [](EigenVectorMap output, ConstEigenVectorMap input0, T) { output = Eigen::cube(input0.array()); }; } } - return BroadcastTwo( + return BroadcastTwo( *context, - [](EigenVectorMap output, float input0, ConstEigenVectorMap input1) { output = Eigen::pow(input0, input1.array()); }, + [](EigenVectorMap output, T input0, ConstEigenVectorMap input1) { output = Eigen::pow(input0, input1.array()); }, input1scalar, - [](EigenVectorMap output, ConstEigenVectorMap input0, ConstEigenVectorMap input1) { output = Eigen::pow(input0.array(), input1.array()); }); + [](EigenVectorMap output, ConstEigenVectorMap input0, ConstEigenVectorMap input1) { output = Eigen::pow(input0.array(), input1.array()); }); } -template <> -Status Exp::Compute(OpKernelContext* ctx) const { +template +Status Exp::Compute(OpKernelContext* ctx) const { auto& X = *ctx->Input(0); auto& Y = *ctx->Output(0, X.Shape()); - EigenMap(Y) = EigenMap(X).array().exp(); + EigenMap(Y) = EigenMap(X).array().exp(); return Status::OK(); } @@ -527,13 +566,13 @@ Status Max_6::Compute(OpKernelContext* ctx) const { return Status::OK(); } -template <> -Status Max_8::Compute(OpKernelContext* context) const { - return BroadcastVariadic( +template +Status Max_8::Compute(OpKernelContext* context) const { + return BroadcastVariadic( Node(), *context, - [](EigenVectorMap output, float input0, ConstEigenVectorMap input1) { output = input1.array().max(input0); }, - [](EigenVectorMap output, ConstEigenVectorMap input0, float input1) { output = input0.array().max(input1); }, - [](EigenVectorMap output, ConstEigenVectorMap input0, ConstEigenVectorMap input1) { output = input0.array().max(input1.array()); }); + [](EigenVectorMap output, T input0, ConstEigenVectorMap input1) { output = input1.array().max(input0); }, + [](EigenVectorMap output, ConstEigenVectorMap input0, T input1) { output = input0.array().max(input1); }, + [](EigenVectorMap output, ConstEigenVectorMap input0, ConstEigenVectorMap input1) { output = input0.array().max(input1.array()); }); } Status Not::Compute(OpKernelContext* context) const { @@ -682,17 +721,25 @@ class Sin final : public OpKernel { Status Compute(OpKernelContext* context) const override { auto& X = *context->Input(0); auto& Y = *context->Output(0, X.Shape()); - MakeEigenArrayMap(Y) = MakeEigenArrayMap(X).sin(); + MakeEigenArrayMap(Y) = MakeEigenArrayMap(X).sin(); return Status::OK(); } }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_TYPED_KERNEL( Sin, 7, + float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Sin); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + Sin, + 7, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Sin); + template class Cos final : public OpKernel { public: diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 007752f83d..8c9143a238 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -22,6 +22,14 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); +#define REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(x, sinceVersion) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + x, \ + sinceVersion, \ + double, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + x); + REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 1); @@ -31,7 +39,9 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 1); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSum, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 1); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 1); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 453f01cf37..81762ff5a1 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -26,7 +26,7 @@ TEST(MathOpTest, Add_int64) { test.Run(); } -TEST(MathOpTest, Add) { +TEST(MathOpTest, Add_float) { OpTester test("Add"); std::vector dims{3, 3}; test.AddInput("A", dims, @@ -49,6 +49,25 @@ TEST(MathOpTest, Add) { #endif } +TEST(MathOpTest, Add_double) { + OpTester test("Add"); + std::vector dims{3, 3}; + test.AddInput("A", dims, + {1.0, 2.0, -1.0, + 0.0, 1.5, -100.0, + -5.4, 9.3, -10'000.0}); + test.AddInput("B", dims, + {-1.0, 4.4, 432.3, + 0.0, 3.5, 64.0, + -5.4, 9.3, 10'000.0}); + test.AddOutput("C", dims, + {0.0, 6.4, 431.3, + 0.0, 5.0, -36.0, + -10.8, 18.6, 0.0}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // Disabling OpenVINO as this type is not supported +} + TEST(MathOpTest, Add_Broadcast_Axis) { OpTester test("Add"); @@ -384,7 +403,7 @@ TEST(MathOpTest, Reciprocal) { test.Run(); } -TEST(MathOpTest, Sqrt) { +TEST(MathOpTest, Sqrt_Float) { OpTester test("Sqrt"); std::vector dims{2, 2}; test.AddInput("X", dims, @@ -396,7 +415,19 @@ TEST(MathOpTest, Sqrt) { test.Run(); } -TEST(MathOpTest, Pow) { +TEST(MathOpTest, Sqrt_Double) { + OpTester test("Sqrt"); + std::vector dims{2, 2}; + test.AddInput("X", dims, + {1.0, 4.0, + 0.0, 9.0}); + test.AddOutput("Y", dims, + {1.0, 2.0, + 0.0, 3.0}); + test.Run(); +} + +TEST(MathOpTest, Pow_Float) { OpTester test("Pow"); std::vector dims{2, 2}; test.AddInput("X", dims, @@ -411,6 +442,21 @@ TEST(MathOpTest, Pow) { test.Run(); } +TEST(MathOpTest, Pow_Double) { + OpTester test("Pow"); + std::vector dims{2, 2}; + test.AddInput("X", dims, + {2.0, 2.0, + std::sqrt(2.0), 1.0}); + test.AddInput("Y", dims, + {0.0, 8.0, + 2.0, 9.0}); + test.AddOutput("Z", dims, + {1.0, 256.0, + 2.0, 1.0}); + test.Run(); +} + TEST(MathOpTest, Pow_Broadcast_Scalar0) { OpTester test("Pow"); @@ -431,7 +477,7 @@ TEST(MathOpTest, Pow_Broadcast_Scalar1) { test.Run(); } -TEST(MathOpTest, Exp) { +TEST(MathOpTest, Exp_float) { OpTester test("Exp"); std::vector dims{2, 2}; test.AddInput("X", dims, @@ -444,6 +490,21 @@ TEST(MathOpTest, Exp) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: result differs } +TEST(MathOpTest, Exp_double) { + OpTester test("Exp"); + std::vector dims{2, 2}; + test.AddInput("X", dims, + {0.0, 1.0, + 2.0, 10.0}); + test.AddOutput("Y", dims, + {1.0, std::exp(1.0), + std::exp(2.0), std::exp(10.0)}); + test.SetOutputRelErr("Y", 1e-7f); + // TODO: Check if this test's result really differs for tensorRT + // For now basing this exclusion based on this test's float counterpart - Exp_float + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Log) { OpTester test("Log"); std::vector dims{2, 2}; @@ -610,7 +671,7 @@ TEST(MathOpTest, Max_6) { test.Run(); } -TEST(MathOpTest, Max_8) { +TEST(MathOpTest, Max_8_Float) { OpTester test("Max", 8); test.AddInput("data_0", {1, 3}, {1.0f, 2.0f, 3.0f}); @@ -627,6 +688,23 @@ TEST(MathOpTest, Max_8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Input batch size is inconsistent } +TEST(MathOpTest, Max_8_Double) { + OpTester test("Max", 8); + test.AddInput("data_0", {1, 3}, + {1.0, 2.0, 3.0}); + test.AddInput("data_2", {3, 3}, + {10.0, 20.0, 30.0, + 40.0, 50.0, 60.0, + 70.0, 80.0, 90.0}); + test.AddInput("data_1", {3, 1}, + {-1.0, -2.0, 300.0}); + test.AddOutput("max", {3, 3}, + {10.0, 20.0, 30.0, + 40.0, 50.0, 60.0, + 300.0, 300.0, 300.0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Input batch size is inconsistent +} + TEST(MathOpTest, Max_8_2inputbroadcast) { OpTester test("Max", 8); test.AddInput("data_0", {1, 3}, @@ -828,7 +906,7 @@ TEST(MathOpTest, Mean_8) { } template -void TrigTest(OpTester& test, std::initializer_list input) { +void TrigFloatTest(OpTester& test, std::initializer_list input) { std::vector dims{static_cast(input.size())}; std::vector output; @@ -840,59 +918,77 @@ void TrigTest(OpTester& test, std::initializer_list input) { test.Run(); } -TEST(MathOpTest, Sin) { +template +void TrigDoubleTest(OpTester& test, std::initializer_list input) { + std::vector dims{static_cast(input.size())}; + + std::vector output; + for (auto v : input) + output.push_back(op(v)); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + test.Run(); +} + +TEST(MathOpTest, SinFloat) { OpTester test("Sin"); - TrigTest(test, {1.1f, -1.1f, 2.2f, -2.2f}); + TrigFloatTest(test, {1.1f, -1.1f, 2.2f, -2.2f}); +} + +TEST(MathOpTest, SinDouble) { + OpTester test("Sin"); + TrigDoubleTest(test, {1.1, -1.1, 2.2, -2.2}); } TEST(MathOpTest, Cos) { OpTester test("Cos"); - TrigTest(test, {1.1f, -1.1f, 2.2f, -2.2f}); + TrigFloatTest(test, {1.1f, -1.1f, 2.2f, -2.2f}); } TEST(MathOpTest, Tan) { OpTester test("Tan"); - TrigTest(test, {-100.0f, -50.0f, 0.0f, 50.0f, 100.0f}); + TrigFloatTest(test, {-100.0f, -50.0f, 0.0f, 50.0f, 100.0f}); } TEST(MathOpTest, Asin) { OpTester test("Asin"); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Acos) { OpTester test("Acos"); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Atan) { OpTester test("Atan"); - TrigTest(test, {-10.0f, -5.0f, 0.0f, 5.0f, 10.0f}); + TrigFloatTest(test, {-10.0f, -5.0f, 0.0f, 5.0f, 10.0f}); } TEST(MathOpTest, Sinh) { OpTester test("Sinh", 9); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Cosh) { OpTester test("Cosh", 9); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Asinh) { OpTester test("Asinh", 9); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Acosh) { OpTester test("Acosh", 9); - TrigTest(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f}); + TrigFloatTest(test, {1.0f, 1.1f, 3.0f, 10.0f, 100.0f}); } TEST(MathOpTest, Atanh) { OpTester test("Atanh", 9); - TrigTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); + TrigFloatTest(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}); } TEST(MathOpTest, Expand_8_3x3) { @@ -999,9 +1095,9 @@ TEST(MathOpTest, Expand_8_3x1x3x1_int64) { test.AddInput("data_0", {1, 3, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); test.AddInput("data_1", {4}, {3, 1, 3, 1}); test.AddOutput("result", {3, 3, 3, 3}, - {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9, - 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9, - 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,}); + {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9, + 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9, + 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,}); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 9a5424ff1a..3a7e0ad762 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -676,6 +676,23 @@ TEST(ReductionOpTest, ReduceSum) { test.Run(); } +TEST(ReductionOpTest, ReduceSum_double) { + OpTester test("ReduceSum"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, + 3.0, 4.0, + + 5.0, 6.0, + 7.0, 8.0, + + 9.0, 10.0, + 11.0, 12.0}); + test.AddOutput("reduced", {1, 2, 1}, {33.0, 45.0}); + test.Run(); +} + TEST(ReductionOpTest, ReduceSum_axes01) { OpTester test("ReduceSum"); test.AddAttribute("axes", std::vector{2}); @@ -798,6 +815,23 @@ TEST(ReductionOpTest, ReduceSumSquare) { test.Run(); } +TEST(ReductionOpTest, ReduceSumSquare_double) { + OpTester test("ReduceSumSquare"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, + 3.0, 4.0, + + 5.0, 6.0, + 7.0, 8.0, + + 9.0, 10.0, + 11.0, 12.0}); + test.AddOutput("reduced", {1, 2, 1}, {247.0, 403.}); + test.Run(); +} + TEST(ReductionOpTest, ReduceSumSquare_int32) { OpTester test("ReduceSumSquare"); test.AddAttribute("axes", std::vector{0, 2}); diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index b0643672d4..4abbc94b82 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -34,6 +34,42 @@ void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, con } } +template <> +void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) { + auto& expected_tensor = expected_data.data_.Get(); + auto* expected = expected_tensor.template Data(); + auto* output = output_tensor.template Data(); + auto size = output_tensor.Shape().Size(); + + bool has_abs_err = expected_data.absolute_error_.has_value(); + bool has_rel_err = expected_data.relative_error_.has_value(); + + double threshold = 0.001; +#ifdef USE_CUDA + threshold = 0.005; +#endif + + for (int i = 0; i < size; ++i) { + if (std::isinf(expected[i])) { // Test infinity for equality + EXPECT_EQ(expected[i], output[i]); + } else if (std::isnan(expected[i])) { + EXPECT_TRUE(std::isnan(output[i])) << "Expected output " << i << " to be NaN"; + } else { + if (!has_abs_err && !has_rel_err) { + // the default for existing tests + EXPECT_NEAR(expected[i], output[i], threshold) << "provider_type: " << provider_type; + } else { + if (has_abs_err) { + EXPECT_NEAR(expected[i], output[i], expected_data.absolute_error_.value()) << "provider_type: " << provider_type; + } + if (has_rel_err) { + EXPECT_NEAR(expected[i], output[i], expected_data.relative_error_.value() * std::abs(expected[i])) << "provider_type: " << provider_type; + } + } + } + } +} + template <> void Check(const OpTester::Data& expected_data, const Tensor& output_tensor, const std::string& provider_type) { auto& expected_tensor = expected_data.data_.Get();