Support double type for Greater CPU (#4373)

* Add double for Greater

* add double type for Greater

* udpate test according to dtype
This commit is contained in:
Bowen Bao 2020-07-13 11:25:14 -07:00 committed by GitHub
parent f18dee84c2
commit 07455cff28
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 1 deletions

View file

@ -98,6 +98,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal);
@ -542,6 +543,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
double, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
float, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
double, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t,
Equal)>,

View file

@ -140,7 +140,8 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 9, double, Less);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int32_t, Less);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 9, int64_t, Less);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, double, Greater);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater);

View file

@ -1360,6 +1360,15 @@ TEST(MathOpTest, Greater_9_float) {
test.Run();
}
TEST(MathOpTest, Greater_9_double) {
OpTester test("Greater", 9);
std::vector<int64_t> dims{4};
test.AddInput<double>("A", dims, {1.0, 0.0, 3.0, -1.0});
test.AddInput<double>("B", dims, {1.0, 1.0, 2.0, -1.0});
test.AddOutput<bool>("C", dims, {false, false, true, false});
test.Run();
}
TEST(MathOpTest, Greater_9_int32) {
OpTester test("Greater", 9);
std::vector<int64_t> dims{4};