add int64 support for less op. (#1604)

This commit is contained in:
Ke Zhang 2019-08-09 17:16:57 -07:00 committed by GitHub
parent 0187d876cb
commit 59c9d83f35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 0 deletions

View file

@ -218,6 +218,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Mea
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Less);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
@ -497,6 +498,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>,

View file

@ -90,6 +90,7 @@ REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less);
REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less);
REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int64_t, Less);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater);

View file

@ -806,6 +806,14 @@ TEST(MathOpTest, Less_Scalar1) {
test.Run();
}
TEST(MathOpTest, Less_int64_Scalar1) {
OpTester test("Less", 9);
test.AddInput<int64_t>("A", {4}, {1, 0, 2, -1});
test.AddInput<int64_t>("B", {1}, {1});
test.AddOutput<bool>("C", {4}, {false, true, false, true});
test.Run();
}
TEST(MathOpTest, Greater_7) {
OpTester test("Greater");
std::vector<int64_t> dims{4};