mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Support double for operators Log, Reciprocal, Sum (CPU) (#6032)
* Support double for operators Log, Reciprocal, Sum * remove tesdt erf_double
This commit is contained in:
parent
8a0f5c50ab
commit
2d09db67b4
3 changed files with 161 additions and 18 deletions
|
|
@ -60,6 +60,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Floor);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Ceil);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 12, float, Add);
|
||||
|
|
@ -87,8 +88,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Exp);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Exp);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Log);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Log);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, double, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, float, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, double, Sum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 11, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Max);
|
||||
|
|
@ -475,6 +479,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Si
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Size);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sum);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, LRN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MeanVarianceNormalization);
|
||||
|
|
@ -536,6 +541,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint32_t, Abs);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Abs);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Floor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sqrt);
|
||||
|
|
@ -544,6 +550,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Re
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
|
||||
|
|
@ -717,6 +724,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Ceil)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float,
|
||||
Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double,
|
||||
Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Neg)>,
|
||||
|
|
@ -730,9 +739,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7,
|
||||
float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7,
|
||||
double, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, double, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7,
|
||||
float, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 11, Min)>,
|
||||
|
|
@ -1361,6 +1374,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Sum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sigmoid)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t,
|
||||
DequantizeLinear)>,
|
||||
|
|
@ -1470,6 +1484,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Abs)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
|
||||
Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
|
||||
Reciprocal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
|
||||
Floor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
|
||||
|
|
@ -1480,6 +1496,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Split)>,
|
||||
|
|
|
|||
|
|
@ -153,7 +153,9 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Ceil, 6, 12, float, Ceil);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Ceil, 13, float, Ceil);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Reciprocal, 6, 12, float, Reciprocal);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Reciprocal, 6, 12, double, Reciprocal);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Reciprocal, 13, float, Reciprocal);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Reciprocal, 13, double, Reciprocal);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, float, Sqrt);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt);
|
||||
|
|
@ -172,12 +174,17 @@ REG_ELEMENTWISE_TYPED_KERNEL(Exp, 13, float, Exp);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Exp, 13, double, Exp);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Log, 6, 12, float, Log);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Log, 6, 12, double, Log);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Log, 13, float, Log);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Log, 13, double, Log);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, float, Sum_6);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 6, 7, double, Sum_6);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 8, 12, float, Sum_8);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sum, 8, 12, double, Sum_8);
|
||||
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, float, Sum_8);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sum, 13, double, Sum_8);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Max, 6, 7, float, Max_6);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Max, 8, 11, Max_8, float, double);
|
||||
|
|
@ -465,49 +472,49 @@ Pow::Compute(OpKernelContext* context) const {
|
|||
return s;
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Sum_6<float>::Compute(OpKernelContext* ctx) const {
|
||||
template <typename T>
|
||||
Status Sum_6<T>::Compute(OpKernelContext* ctx) const {
|
||||
auto input_count = Node().InputArgCount().front();
|
||||
ORT_ENFORCE(input_count >= 1, "Must have 1 or more inputs");
|
||||
auto& data_0 = *ctx->Input<Tensor>(0);
|
||||
auto& shape = data_0.Shape();
|
||||
auto sum = EigenMap<float>(*ctx->Output(0, shape));
|
||||
auto sum = EigenMap<T>(*ctx->Output(0, shape));
|
||||
|
||||
if (input_count == 1) {
|
||||
sum = EigenMap<float>(data_0);
|
||||
sum = EigenMap<T>(data_0);
|
||||
} else {
|
||||
auto& data_1 = *ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(data_1.Shape() == shape, "All inputs must have the same shape");
|
||||
|
||||
sum = EigenMap<float>(data_0) + EigenMap<float>(data_1);
|
||||
sum = EigenMap<T>(data_0) + EigenMap<T>(data_1);
|
||||
for (int index = 2; index < input_count; index++) {
|
||||
auto& data_n = *ctx->Input<Tensor>(index);
|
||||
ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape");
|
||||
sum += EigenMap<float>(data_n);
|
||||
sum += EigenMap<T>(data_n);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Sum_8<float>::Compute(OpKernelContext* context) const {
|
||||
template <typename T>
|
||||
Status Sum_8<T>::Compute(OpKernelContext* context) const {
|
||||
const auto typed_allocator = [](const TensorAllocator& tensor_allocator, const TensorShape& shape) {
|
||||
return tensor_allocator.Allocate<float>(shape);
|
||||
return tensor_allocator.Allocate<T>(shape);
|
||||
};
|
||||
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<float>() =
|
||||
per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
|
||||
per_iter_bh.OutputEigen<T>() =
|
||||
per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<float>() =
|
||||
per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
|
||||
per_iter_bh.OutputEigen<T>() =
|
||||
per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<float>() =
|
||||
per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
|
||||
per_iter_bh.OutputEigen<T>() =
|
||||
per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}};
|
||||
|
||||
int input_count = Node().InputArgCount().front();
|
||||
|
|
|
|||
|
|
@ -587,6 +587,18 @@ TEST(MathOpTest, Reciprocal) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Reciprocal_double) {
|
||||
OpTester test("Reciprocal");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims,
|
||||
{1.0, 2.0,
|
||||
-1.0, -2.0});
|
||||
test.AddOutput<double>("Y", dims,
|
||||
{1.0, 0.5,
|
||||
-1.0, -0.5});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sqrt_Float) {
|
||||
OpTester test("Sqrt");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
|
|
@ -833,6 +845,19 @@ TEST(MathOpTest, Log) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Log_double) {
|
||||
OpTester test("Log");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<double>("X", dims,
|
||||
{1.0, 2.0,
|
||||
5.0, 10.0});
|
||||
test.AddOutput<double>("Y", dims,
|
||||
{0.0, std::log(2.0),
|
||||
std::log(5.0), std::log(10.0)});
|
||||
test.SetOutputRelErr("Y", 1e-7f);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sum_6) {
|
||||
OpTester test("Sum", 6);
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
|
|
@ -860,6 +885,33 @@ TEST(MathOpTest, Sum_6) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sum_6_double) {
|
||||
OpTester test("Sum", 6);
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
test.AddInput<double>("data_0", dims,
|
||||
{1.0, 0.0, 1.0,
|
||||
-1.0, 1.1, -100.0,
|
||||
-5.4, 0.01, -10000.0});
|
||||
test.AddInput<double>("data_1", dims,
|
||||
{1.0, 0.0, 2.0,
|
||||
-2.0, 2.2, 64.0,
|
||||
-1.0, 0.02, 0.25});
|
||||
test.AddInput<double>("data_3", dims,
|
||||
{1.0, 0.0, 3.0,
|
||||
-3.0, 3.3, 64.0,
|
||||
5.4, 0.03, 10000.0});
|
||||
test.AddOutput<double>("sum", dims,
|
||||
{3.0, 0.0, 6.0,
|
||||
-6.0, 6.6, 28.0,
|
||||
-1.0, 0.06, 0.25});
|
||||
|
||||
#if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_GPU_FP16)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); // OpenVINO EP: Disabled due to accuracy mismatch for FP16
|
||||
#else
|
||||
test.Run();
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sum_8_Test1) {
|
||||
OpTester test("Sum", 8);
|
||||
test.AddInput<float>("data_0", {3}, {1.0f, 2.0f, 3.0f});
|
||||
|
|
@ -886,6 +938,31 @@ TEST(MathOpTest, Sum_8_Test1) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sum_8_Test1_double) {
|
||||
OpTester test("Sum", 8);
|
||||
test.AddInput<double>("data_0", {3}, {1.0, 2.0, 3.0});
|
||||
test.AddInput<double>("data_1", {3, 1}, {10.0, 20.0, 30.0});
|
||||
test.AddInput<double>("data_2", {3, 1, 1}, {100.0, 200.0, 300.0});
|
||||
test.AddOutput<double>("sum", {3, 3, 3},
|
||||
{111.0, 112.0, 113.0,
|
||||
121.0, 122.0, 123.0,
|
||||
131.0, 132.0, 133.0,
|
||||
|
||||
211.0, 212.0, 213.0,
|
||||
221.0, 222.0, 223.0,
|
||||
231.0, 232.0, 233.0,
|
||||
|
||||
311.0, 312.0, 313.0,
|
||||
321.0, 322.0, 323.0,
|
||||
331.0, 332.0, 333.0});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
|
||||
//OpenVINO: Disabled due to software limitation for GPU and VPU Plugins.
|
||||
//This test runs fine on CPU Plugin
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Expected output shape [{3,3,3}] did not match run output shape [{3,1,1}] for sum
|
||||
#endif
|
||||
}
|
||||
TEST(MathOpTest, Sum_8_Test2) {
|
||||
OpTester test("Sum", 8);
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
|
|
@ -913,15 +990,50 @@ TEST(MathOpTest, Sum_8_Test2) {
|
|||
59.6f, 64.01f, -8.0f});
|
||||
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
|
||||
// OpenVINO: Disabled temporarily due to accuarcy issues
|
||||
// OpenVINO: Disabled temporarily due to accuracy issues
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); //TensorRT: Input batch size is inconsistent
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "Sum is not correct", {kTensorrtExecutionProvider}); //TensorRT: result differs
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Sum_8_Test2_double) {
|
||||
OpTester test("Sum", 8);
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
test.AddInput<double>("data_0", dims,
|
||||
{
|
||||
1.0,
|
||||
0.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
1.1,
|
||||
-100.0,
|
||||
-5.4,
|
||||
0.01,
|
||||
-74.0,
|
||||
});
|
||||
std::vector<int64_t> dims_1{3};
|
||||
test.AddInput<double>("data_1", dims_1,
|
||||
{1.0, 0.0, 2.0});
|
||||
std::vector<int64_t> dims_2{3, 1};
|
||||
test.AddInput<double>("data_2", dims_2,
|
||||
{-3.0, 3.3, 64.0});
|
||||
test.AddOutput<double>("sum", dims,
|
||||
{-1.0, -3.0, 0.0,
|
||||
3.3, 4.4, -94.7,
|
||||
59.6, 64.01, -8.0});
|
||||
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
|
||||
// OpenVINO: Disabled temporarily due to accuracy issues
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); //TensorRT: Input batch size is inconsistent
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "Sum is not correct", {kTensorrtExecutionProvider}); //TensorRT: result differs
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorShape& shape) {
|
||||
using element_type = float;
|
||||
using element_type = T;
|
||||
|
||||
OpTester test{"Sum", 8};
|
||||
|
||||
|
|
@ -949,7 +1061,14 @@ static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorS
|
|||
TEST(MathOpTest, SumMultipleInputsNoBroadcasting) {
|
||||
const TensorShape shape{3, 3, 3};
|
||||
for (size_t num_inputs = 2; num_inputs < 10; ++num_inputs) {
|
||||
TestSumMultipleInputsNoBroadcasting(num_inputs, shape);
|
||||
TestSumMultipleInputsNoBroadcasting<float>(num_inputs, shape);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MathOpTest, SumMultipleInputsNoBroadcasting_double) {
|
||||
const TensorShape shape{3, 3, 3};
|
||||
for (size_t num_inputs = 2; num_inputs < 10; ++num_inputs) {
|
||||
TestSumMultipleInputsNoBroadcasting<double>(num_inputs, shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue