mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add BitShift operator (#1981)
* Add BitShift operator. Enable uint32 and uint64 support initially.
This commit is contained in:
parent
d5d1719c1f
commit
fdbe365c37
6 changed files with 197 additions and 40 deletions
|
|
@ -392,6 +392,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, BitShift);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Range);
|
||||
|
|
@ -1010,6 +1012,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, BitShift)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Range)>,
|
||||
|
|
|
|||
|
|
@ -18,13 +18,14 @@ namespace onnxruntime {
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
#define REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
TYPE, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
#define REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION, \
|
||||
TYPE, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
#define REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
|
||||
|
|
@ -36,12 +37,13 @@ namespace onnxruntime {
|
|||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
#define REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
TYPE, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
OP_TYPE, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
TYPE, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
KERNEL_CLASS<TYPE>);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Add, 7, float, Add);
|
||||
|
|
@ -124,6 +126,11 @@ REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, float, Equal);
|
|||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Mean, 8, float, Mean_8);
|
||||
|
||||
// BitShift(11) CPU kernels. Only the uint32_t and uint64_t instantiations are
// registered; the uint8_t and uint16_t registrations below are deliberately
// left commented out (presumably to limit binary size -- confirm before
// enabling them).
//REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift);
|
||||
//REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint16_t, BitShift);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf);
|
||||
|
||||
// REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Not, 1, bool, Not);
|
||||
|
|
@ -134,29 +141,33 @@ REG_ELEMENTWISE_TYPED_KERNEL(Erf, 9, float, Erf);
|
|||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Not,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
Not);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
And,
|
||||
7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
And);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Or,
|
||||
7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
Or);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Xor,
|
||||
7,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()),
|
||||
Xor);
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -501,6 +512,67 @@ Status Mean_8<float>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BitShift<T>::BitShift(const OpKernelInfo& info) : OpKernel(info) {
|
||||
std::string direction;
|
||||
auto status = info.GetAttr("direction", &direction);
|
||||
ORT_ENFORCE(status.IsOK(), status);
|
||||
|
||||
if (direction == "LEFT")
|
||||
shift_left_ = true;
|
||||
else if (direction == "RIGHT")
|
||||
shift_left_ = false;
|
||||
else
|
||||
ORT_THROW("Invalid direction value of '", direction, "'. Valid values are 'LEFT' or 'RIGHT'.");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BitShift<T>::Compute(OpKernelContext* context) const {
|
||||
return BroadcastTwo<T, T>(
|
||||
*context,
|
||||
[this](EigenVectorMap<T> output, T input0, ConstEigenVectorMap<T> input1) {
|
||||
int64_t i = 0;
|
||||
if (shift_left_) {
|
||||
for (const auto& input : input1.array()) {
|
||||
output[i++] = input0 << input;
|
||||
}
|
||||
} else {
|
||||
for (const auto& input : input1.array()) {
|
||||
output[i++] = input0 >> input;
|
||||
}
|
||||
}
|
||||
},
|
||||
[this](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, T input1) {
|
||||
int64_t i = 0;
|
||||
if (shift_left_) {
|
||||
for (const auto& input : input0.array()) {
|
||||
output[i++] = input << input1;
|
||||
}
|
||||
} else {
|
||||
for (const auto& input : input0.array()) {
|
||||
output[i++] = input >> input1;
|
||||
}
|
||||
}
|
||||
},
|
||||
[this](EigenVectorMap<T> output, ConstEigenVectorMap<T> input0, ConstEigenVectorMap<T> input1) {
|
||||
auto cur0 = input0.begin(), end0 = input0.end();
|
||||
auto cur1 = input1.begin(), end1 = input1.end();
|
||||
auto cur_out = output.begin(), end_out = output.end();
|
||||
if (shift_left_) {
|
||||
for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) {
|
||||
*cur_out = *cur0 << *cur1;
|
||||
}
|
||||
} else {
|
||||
for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) {
|
||||
*cur_out = *cur0 >> *cur1;
|
||||
}
|
||||
}
|
||||
|
||||
ORT_ENFORCE(cur1 == end1);
|
||||
ORT_ENFORCE(cur_out == end_out);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Sin final : public OpKernel {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -269,6 +269,16 @@ class Mean_8 final : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BitShift final : public OpKernel {
|
||||
public:
|
||||
explicit BitShift(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
bool shift_left_;
|
||||
};
|
||||
|
||||
// PRelu is activation function, but it's closer to binary elementwise ops in implementation
|
||||
template <typename T>
|
||||
class PRelu final : public OpKernel {
|
||||
|
|
@ -536,8 +546,8 @@ struct TensorAllocator {
|
|||
|
||||
std::unique_ptr<Tensor> Allocate(const TensorShape& shape) {
|
||||
return onnxruntime::make_unique<Tensor>(DataTypeImpl::GetType<T>(),
|
||||
shape,
|
||||
allocator_);
|
||||
shape,
|
||||
allocator_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -432,14 +432,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"onehot_with_negative_axis", "OneHot(11) not implemented yet"},
|
||||
{"onehot_with_axis", "OneHot(11) not implemented yet"},
|
||||
{"onehot_negative_indices", "OneHot(11) not implemented yet"},
|
||||
{"bitshift_right_uint8", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_right_uint64", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_right_uint32", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_right_uint16", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint8", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint64", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint32", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_left_uint16", "BitShift(11) not implemented yet"},
|
||||
{"bitshift_right_uint8", "BitShift(11) uint8 support not enabled currently"},
|
||||
{"bitshift_right_uint16", "BitShift(11) uint16 support not enabled currently"},
|
||||
{"bitshift_left_uint8", "BitShift(11) uint8 support not enabled currently"},
|
||||
{"bitshift_left_uint16", "BitShift(11) uint16 support not enabled currently"},
|
||||
{"reflect_pad", "test data type `int32_t` not supported yet, the `float` equivalent is covered via unit tests"},
|
||||
{"edge_pad", "test data type `int32_t` not supported yet, the `float` equivalent is covered via unit tests"},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1547,5 +1547,77 @@ TEST(ModOpTest, Int32_mod_bcast) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Left shift with two inputs of identical shape: 16<<1, 4<<2, 1<<3.
TEST(BitShiftOpTest, SimpleLeft) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "LEFT");

  tester.AddInput<uint32_t>("X", {3}, {16, 4, 1});
  tester.AddInput<uint32_t>("Y", {3}, {1, 2, 3});
  tester.AddOutput<uint32_t>("Z", {3}, {32, 16, 8});

  tester.Run();
}
|
||||
|
||||
// Right shift with two inputs of identical shape: 16>>1, 4>>2, 1>>3.
TEST(BitShiftOpTest, SimpleRight) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "RIGHT");

  tester.AddInput<uint32_t>("X", {3}, {16, 4, 1});
  tester.AddInput<uint32_t>("Y", {3}, {1, 2, 3});
  tester.AddOutput<uint32_t>("Z", {3}, {8, 1, 0});

  tester.Run();
}
|
||||
|
||||
// Left shift where X holds a single element applied to every shift amount in
// Y: 16<<1, 16<<2, 16<<3.
TEST(BitShiftOpTest, ScalarLeftX) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "LEFT");

  tester.AddInput<uint32_t>("X", {1}, {16});
  tester.AddInput<uint32_t>("Y", {3}, {1, 2, 3});
  tester.AddOutput<uint32_t>("Z", {3}, {32, 64, 128});

  tester.Run();
}
|
||||
|
||||
// Left shift where Y holds a single shift amount applied to every element of
// X: 16<<1, 4<<1, 1<<1.
TEST(BitShiftOpTest, ScalarLeftY) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "LEFT");

  tester.AddInput<uint32_t>("X", {3}, {16, 4, 1});
  tester.AddInput<uint32_t>("Y", {1}, {1});
  tester.AddOutput<uint32_t>("Z", {3}, {32, 8, 2});

  tester.Run();
}
|
||||
|
||||
// Right shift where X holds a single element applied to every shift amount in
// Y: 16>>1, 16>>2, 16>>3.
TEST(BitShiftOpTest, ScalarRightX) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "RIGHT");

  tester.AddInput<uint32_t>("X", {1}, {16});
  tester.AddInput<uint32_t>("Y", {3}, {1, 2, 3});
  tester.AddOutput<uint32_t>("Z", {3}, {8, 4, 2});

  tester.Run();
}
|
||||
|
||||
// Right shift where Y holds a single shift amount applied to every element of
// X: 16>>1, 4>>1, 1>>1.
TEST(BitShiftOpTest, ScalarRightY) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "RIGHT");

  tester.AddInput<uint32_t>("X", {3}, {16, 4, 1});
  tester.AddInput<uint32_t>("Y", {1}, {1});
  tester.AddOutput<uint32_t>("Z", {3}, {8, 2, 0});

  tester.Run();
}
|
||||
|
||||
// Left shift on uint64_t where the {2}-shaped Y is broadcast across each row
// of the {3, 2}-shaped X: every row is shifted by {1, 2}.
TEST(BitShiftOpTest, BroadcastYLeft) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "LEFT");

  tester.AddInput<uint64_t>("X", {3, 2}, {1, 2, 3, 4, 5, 6});
  tester.AddInput<uint64_t>("Y", {2}, {1, 2});
  tester.AddOutput<uint64_t>("Z", {3, 2}, {2, 8, 6, 16, 10, 24});

  tester.Run();
}
|
||||
|
||||
// Right shift on uint64_t where the {2}-shaped X is broadcast across each row
// of the {3, 2}-shaped Y: {64, 32} is shifted by each row of shift amounts.
TEST(BitShiftOpTest, BroadcastXRight) {
  OpTester tester("BitShift", 11);
  tester.AddAttribute("direction", "RIGHT");

  tester.AddInput<uint64_t>("X", {2}, {64, 32});
  tester.AddInput<uint64_t>("Y", {3, 2}, {1, 2, 3, 4, 5, 6});
  tester.AddOutput<uint64_t>("Z", {3, 2}, {32, 8, 8, 2, 2, 0});

  tester.Run();
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -90,6 +90,16 @@ def other_tests_failing_permanently_filters():
|
|||
return filters
|
||||
|
||||
|
||||
|
||||
def test_with_types_disabled_due_to_binary_size_concerns_filters():
    """Return the exclusion patterns for the uint8/uint16 BitShift backend tests.

    Each entry is a regex anchored at the start of a test name. A fresh list is
    built on every call, so callers may concatenate or mutate the result.
    """
    return [
        '^test_bitshift_right_uint16_cpu',
        '^test_bitshift_right_uint8_cpu',
        '^test_bitshift_left_uint16_cpu',
        '^test_bitshift_left_uint8_cpu',
    ]
|
||||
|
||||
|
||||
def create_backend_test(testname=None):
|
||||
backend_test = OrtBackendTest(c2, __name__)
|
||||
|
||||
|
|
@ -103,14 +113,6 @@ def create_backend_test(testname=None):
|
|||
current_failing_tests = [#'^test_cast_STRING_to_FLOAT_cpu', # old test data that is bad on Linux CI builds
|
||||
'^test_qlinearconv_cpu',
|
||||
'^test_gru_seq_length_cpu',
|
||||
'^test_bitshift_right_uint16_cpu',
|
||||
'^test_bitshift_right_uint32_cpu',
|
||||
'^test_bitshift_right_uint64_cpu',
|
||||
'^test_bitshift_right_uint8_cpu',
|
||||
'^test_bitshift_left_uint16_cpu',
|
||||
'^test_bitshift_left_uint32_cpu',
|
||||
'^test_bitshift_left_uint64_cpu',
|
||||
'^test_bitshift_left_uint8_cpu',
|
||||
'^test_dynamicquantizelinear_expanded.*',
|
||||
'^test_dynamicquantizelinear_max_adjusted_expanded.*',
|
||||
'^test_dynamicquantizelinear_min_adjusted_expanded.*',
|
||||
|
|
@ -176,7 +178,8 @@ def create_backend_test(testname=None):
|
|||
filters = current_failing_tests + \
|
||||
tests_with_pre_opset7_dependencies_filters() + \
|
||||
unsupported_usages_filters() + \
|
||||
other_tests_failing_permanently_filters()
|
||||
other_tests_failing_permanently_filters() + \
|
||||
test_with_types_disabled_due_to_binary_size_concerns_filters()
|
||||
|
||||
backend_test.exclude('(' + '|'.join(filters) + ')')
|
||||
print('excluded tests:', filters)
|
||||
|
|
|
|||
Loading…
Reference in a new issue