Enabled bool input type for Equal for op_ver 11 (#2034)

This change enabled bool type for Equal-11's inputs
This commit is contained in:
Yang Chen 2019-10-08 01:50:37 -07:00 committed by GitHub
parent 203c2f5b59
commit 19b0d0af87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 0 deletions

View file

@ -91,6 +91,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7, float, Mean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Mean);
@ -505,6 +506,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 7,

View file

@ -121,6 +121,7 @@ REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, bool, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int32_t, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int64_t, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, bool, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, float, Equal);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6);

View file

@ -995,6 +995,15 @@ TEST(MathOpTest, Equal_bool) {
test.Run();
}
TEST(MathOpTest, Equal_11_bool) {
OpTester test("Equal", 11);
std::vector<int64_t> dims{4};
test.AddInput<bool>("A", dims, {false, true, false, true});
test.AddInput<bool>("B", dims, {true, true, true, true});
test.AddOutput<bool>("C", dims, {false, true, false, true});
test.Run();
}
TEST(MathOpTest, Equal_bool_scalar0) {
OpTester test("Equal");
test.AddInput<bool>("A", {1}, {false});