From 07455cff289d411a3dd0bc58c728d4cc015249ed Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Mon, 13 Jul 2020 11:25:14 -0700 Subject: [PATCH] Support double type for Greater CPU (#4373) * Add double for Greater * add double type for Greater * udpate test according to dtype --- onnxruntime/core/providers/cpu/cpu_execution_provider.cc | 3 +++ onnxruntime/core/providers/cpu/math/element_wise_ops.cc | 3 ++- .../test/providers/cpu/math/element_wise_ops_test.cc | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f77cb09dbb..4bc6d526ad 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal); @@ -542,6 +543,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Less)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 2efa823f6f..497e3f89c2 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -140,7 +140,8 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 9, double, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int32_t, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int64_t, Less); -REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, double, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index cf20e27fc8..8ad723980f 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1360,6 +1360,15 @@ TEST(MathOpTest, Greater_9_float) { test.Run(); } +TEST(MathOpTest, Greater_9_double) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {1.0, 0.0, 3.0, -1.0}); + test.AddInput("B", dims, {1.0, 1.0, 2.0, -1.0}); + test.AddOutput("C", dims, {false, false, true, false}); + test.Run(); +} + TEST(MathOpTest, Greater_9_int32) { OpTester test("Greater", 9); std::vector dims{4};