From 59c9d83f357083688d5d55fd2dde90bf32c3ca04 Mon Sep 17 00:00:00 2001 From: Ke Zhang Date: Fri, 9 Aug 2019 17:16:57 -0700 Subject: [PATCH] add int64 support for less op. (#1604) --- onnxruntime/core/providers/cpu/cpu_execution_provider.cc | 2 ++ onnxruntime/core/providers/cpu/math/element_wise_ops.cc | 1 + .../test/providers/cpu/math/element_wise_ops_test.cc | 8 ++++++++ 3 files changed, 11 insertions(+) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index fbb19fe332..bb95212533 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -218,6 +218,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Mea class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN); @@ -497,6 +498,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index ece68834e7..33d330f7cf 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -90,6 +90,7 @@ REG_ELEMENTWISE_TYPED_KERNEL(Max, 8, double, Max_8); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Less, 7, 9, float, Less); REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int32_t, Less); +REG_ELEMENTWISE_TYPED_KERNEL(Less, 9, int64_t, Less); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) REG_ELEMENTWISE_TYPED_KERNEL(Greater, 9, int32_t, Greater); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 80770af33a..c77f5cd224 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -806,6 +806,14 @@ TEST(MathOpTest, Less_Scalar1) { test.Run(); } +TEST(MathOpTest, Less_int64_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, -1}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {false, true, false, true}); + test.Run(); +} + TEST(MathOpTest, Greater_7) { OpTester test("Greater"); std::vector dims{4};