From 19b0d0af87b2dfaca56462bf810c55c7e67bd9ab Mon Sep 17 00:00:00 2001 From: Yang Chen <40417152+yangchen-MS@users.noreply.github.com> Date: Tue, 8 Oct 2019 01:50:37 -0700 Subject: [PATCH] Enabled bool input type for Equal for op_ver 11 (#2034) This change enabled bool type for Equal-11's inputs --- onnxruntime/core/providers/cpu/cpu_execution_provider.cc | 3 +++ onnxruntime/core/providers/cpu/math/element_wise_ops.cc | 1 + .../test/providers/cpu/math/element_wise_ops_test.cc | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 8aa6ab91e3..4f1de06a03 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -91,6 +91,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean); @@ -505,6 +506,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Equal)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo dims{4}; + test.AddInput("A", dims, {false, true, false, true}); + test.AddInput("B", dims, {true, true, true, true}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + TEST(MathOpTest, Equal_bool_scalar0) { OpTester test("Equal"); test.AddInput("A", {1}, {false});