From fdbe365c37745556629ffe28e0c56e0dbce6646f Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 5 Oct 2019 07:48:58 +1000 Subject: [PATCH] Add BitShift operator (#1981) * Add BitShift operator. Enable uint32 and uint64 support initially. --- .../providers/cpu/cpu_execution_provider.cc | 4 + .../providers/cpu/math/element_wise_ops.cc | 114 ++++++++++++++---- .../providers/cpu/math/element_wise_ops.h | 14 ++- onnxruntime/test/onnx/main.cc | 12 +- .../cpu/math/element_wise_ops_test.cc | 72 +++++++++++ .../test/python/onnx_backend_test_series.py | 21 ++-- 6 files changed, 197 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 0bcc81ee32..3283330328 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -392,6 +392,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, BitShift); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Range); @@ -1010,6 +1012,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index b94b12ca9e..2a72d211e1 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -18,13 +18,14 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); -#define REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - OP_TYPE, \ - VERSION, \ - TYPE, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ +#define REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + OP_TYPE, \ + VERSION, \ + TYPE, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); #define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ @@ -36,12 +37,13 @@ namespace onnxruntime { KERNEL_CLASS); #define REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ - ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ - OP_TYPE, \ - VERSION_FROM, VERSION_TO, \ - TYPE, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + OP_TYPE, \ + VERSION_FROM, VERSION_TO, \ + TYPE, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add); @@ -124,6 +126,11 @@ REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, float, Equal); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6); REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8); +//REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift); +//REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint16_t, BitShift); +REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift); +REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift); + REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf); // REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Not, 1, bool, Not); @@ -134,29 +141,33 @@ REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf); ONNX_CPU_OPERATOR_KERNEL( Not, 1, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Not); ONNX_CPU_OPERATOR_KERNEL( And, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), And); ONNX_CPU_OPERATOR_KERNEL( Or, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Or); ONNX_CPU_OPERATOR_KERNEL( Xor, 7, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Xor); template @@ -501,6 +512,67 @@ Status Mean_8::Compute(OpKernelContext* context) const { return Status::OK(); } +template +BitShift::BitShift(const OpKernelInfo& info) : OpKernel(info) { + std::string direction; + auto status = info.GetAttr("direction", &direction); + ORT_ENFORCE(status.IsOK(), status); + + if (direction == "LEFT") + shift_left_ = true; + else if (direction == "RIGHT") + shift_left_ = false; + else + ORT_THROW("Invalid direction value of '", direction, "'. Valid values are 'LEFT' or 'RIGHT'."); +} + +template +Status BitShift::Compute(OpKernelContext* context) const { + return BroadcastTwo( + *context, + [this](EigenVectorMap output, T input0, ConstEigenVectorMap input1) { + int64_t i = 0; + if (shift_left_) { + for (const auto& input : input1.array()) { + output[i++] = input0 << input; + } + } else { + for (const auto& input : input1.array()) { + output[i++] = input0 >> input; + } + } + }, + [this](EigenVectorMap output, ConstEigenVectorMap input0, T input1) { + int64_t i = 0; + if (shift_left_) { + for (const auto& input : input0.array()) { + output[i++] = input << input1; + } + } else { + for (const auto& input : input0.array()) { + output[i++] = input >> input1; + } + } + }, + [this](EigenVectorMap output, ConstEigenVectorMap input0, ConstEigenVectorMap input1) { + auto cur0 = input0.begin(), end0 = input0.end(); + auto cur1 = input1.begin(), end1 = input1.end(); + auto cur_out = output.begin(), end_out = output.end(); + if (shift_left_) { + for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) { + *cur_out = *cur0 << *cur1; + } + } else { + for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) { + *cur_out = *cur0 >> *cur1; + } + } + + ORT_ENFORCE(cur1 == end1); + ORT_ENFORCE(cur_out == end_out); + }); +} + template class Sin final : public OpKernel { public: diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 50f1a9f925..4bbd0f93ac 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -269,6 +269,16 @@ class Mean_8 final : public OpKernel { Status Compute(OpKernelContext* context) const override; }; +template +class BitShift final : public OpKernel { + public: + explicit BitShift(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + bool shift_left_; +}; + // PRelu is activation function, but it's closer to binary elementwise ops in implementation template class PRelu final : public OpKernel { @@ -536,8 +546,8 @@ struct TensorAllocator { std::unique_ptr Allocate(const TensorShape& shape) { return onnxruntime::make_unique(DataTypeImpl::GetType(), - shape, - allocator_); + shape, + allocator_); } private: diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6c5fbae387..9c9871eb6a 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -432,14 +432,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"onehot_with_negative_axis", "OneHot(11) not implemented yet"}, {"onehot_with_axis", "OneHot(11) not implemented yet"}, {"onehot_negative_indices", "OneHot(11) not implemented yet"}, - {"bitshift_right_uint8", "BitShift(11) not implemented yet"}, - {"bitshift_right_uint64", "BitShift(11) not implemented yet"}, - {"bitshift_right_uint32", "BitShift(11) not implemented yet"}, - {"bitshift_right_uint16", "BitShift(11) not implemented yet"}, - {"bitshift_left_uint8", "BitShift(11) not implemented yet"}, - {"bitshift_left_uint64", "BitShift(11) not implemented yet"}, - {"bitshift_left_uint32", "BitShift(11) not implemented yet"}, - {"bitshift_left_uint16", "BitShift(11) not implemented yet"}, + {"bitshift_right_uint8", "BitShift(11) uint8 support not enabled currently"}, + {"bitshift_right_uint16", "BitShift(11) uint16 support not enabled currently"}, + {"bitshift_left_uint8", "BitShift(11) uint8 support not enabled currently"}, + {"bitshift_left_uint16", "BitShift(11) uint16 support not enabled currently"}, {"reflect_pad", "test data type `int32_t` not supported yet, the `float` equivalent is covered via unit tests"}, {"edge_pad", "test data type `int32_t` not supported yet, the `float` equivalent is covered via unit tests"}, }; diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 0a2b105ec2..9c5a48e966 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1547,5 +1547,77 @@ TEST(ModOpTest, Int32_mod_bcast) { test.Run(); } +TEST(BitShiftOpTest, SimpleLeft) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {3}, {1, 2, 3}); + test.AddOutput("Z", {3}, {32, 16, 8}); + test.Run(); +} + +TEST(BitShiftOpTest, SimpleRight) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {3}, {1, 2, 3}); + test.AddOutput("Z", {3}, {8, 1, 0}); + test.Run(); +} + +TEST(BitShiftOpTest, ScalarLeftX) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {1}, {16}); + test.AddInput("Y", {3}, {1, 2, 3}); + test.AddOutput("Z", {3}, {32, 64, 128}); + test.Run(); +} + +TEST(BitShiftOpTest, ScalarLeftY) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {1}, {1}); + test.AddOutput("Z", {3}, {32, 8, 2}); + test.Run(); +} + +TEST(BitShiftOpTest, ScalarRightX) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {1}, {16}); + test.AddInput("Y", {3}, {1, 2, 3}); + test.AddOutput("Z", {3}, {8, 4, 2}); + test.Run(); +} + +TEST(BitShiftOpTest, ScalarRightY) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {1}, {1}); + test.AddOutput("Z", {3}, {8, 2, 0}); + test.Run(); +} + +TEST(BitShiftOpTest, BroadcastYLeft) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {3, 2}, {1, 2, 3, 4, 5, 6}); + test.AddInput("Y", {2}, {1, 2}); + test.AddOutput("Z", {3, 2}, {2, 8, 6, 16, 10, 24}); + test.Run(); +} + +TEST(BitShiftOpTest, BroadcastXRight) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {2}, {64, 32}); + test.AddInput("Y", {3, 2}, {1, 2, 3, 4, 5, 6}); + test.AddOutput("Z", {3, 2}, {32, 8, 8, 2, 2, 0}); + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index fc93440659..7a532f1c09 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -90,6 +90,16 @@ def other_tests_failing_permanently_filters(): return filters + +def test_with_types_disabled_due_to_binary_size_concerns_filters(): + filters = ['^test_bitshift_right_uint16_cpu', + '^test_bitshift_right_uint8_cpu', + '^test_bitshift_left_uint16_cpu', + '^test_bitshift_left_uint8_cpu'] + + return filters + + def create_backend_test(testname=None): backend_test = OrtBackendTest(c2, __name__) @@ -103,14 +113,6 @@ def create_backend_test(testname=None): current_failing_tests = [#'^test_cast_STRING_to_FLOAT_cpu', # old test data that is bad on Linux CI builds '^test_qlinearconv_cpu', '^test_gru_seq_length_cpu', - '^test_bitshift_right_uint16_cpu', - '^test_bitshift_right_uint32_cpu', - '^test_bitshift_right_uint64_cpu', - '^test_bitshift_right_uint8_cpu', - '^test_bitshift_left_uint16_cpu', - '^test_bitshift_left_uint32_cpu', - '^test_bitshift_left_uint64_cpu', - '^test_bitshift_left_uint8_cpu', '^test_dynamicquantizelinear_expanded.*', '^test_dynamicquantizelinear_max_adjusted_expanded.*', '^test_dynamicquantizelinear_min_adjusted_expanded.*', @@ -176,7 +178,8 @@ def create_backend_test(testname=None): filters = current_failing_tests + \ tests_with_pre_opset7_dependencies_filters() + \ unsupported_usages_filters() + \ - other_tests_failing_permanently_filters() + other_tests_failing_permanently_filters() + \ + test_with_types_disabled_due_to_binary_size_concerns_filters() backend_test.exclude('(' + '|'.join(filters) + ')') print('excluded tests:', filters)