mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
cpu to support bitwise ops (#14197)
This commit is contained in:
parent
77b455b969
commit
7b6d880b28
5 changed files with 256 additions and 2 deletions
|
|
@ -44,6 +44,10 @@ Do not modify directly.*
|
|||
|||[9, 13]|**T** = tensor(double), tensor(float)|
|
||||
|||[7, 8]|**T** = tensor(double), tensor(float)|
|
||||
|BitShift|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseAnd|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -830,6 +830,38 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseAnd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseNot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseOr);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseXor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
|
|
@ -2131,6 +2163,38 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
|
||||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseAnd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseNot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseOr)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint16_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint32_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint64_t, BitwiseXor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
|
|
|
|||
|
|
@ -374,6 +374,42 @@ REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int8_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int16_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int32_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, int64_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint8_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint16_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint32_t, BitwiseAnd);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseAnd, 18, uint64_t, BitwiseAnd);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int8_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int16_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int32_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, int64_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint8_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint16_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint32_t, BitwiseNot);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseNot, 18, uint64_t, BitwiseNot);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int8_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int16_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int32_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, int64_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint8_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint16_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint32_t, BitwiseOr);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseOr, 18, uint64_t, BitwiseOr);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int8_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int16_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int32_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, int64_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint8_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint16_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint32_t, BitwiseXor);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitwiseXor, 18, uint64_t, BitwiseXor);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Erf, 9, 12, float, Erf);
|
||||
// Supposed to add BFloat16 but we are not supporting now, however, separate registration
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Erf, 13, float, Erf);
|
||||
|
|
@ -1155,7 +1191,122 @@ Status BitShift<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
class Sin final : public OpKernel {
|
||||
Status BitwiseAnd<T>::Compute(OpKernelContext* context) const {
|
||||
ProcessBroadcastSpanFuncs funcs {
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
const T X = per_iter_bh.ScalarInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(Y.begin(), Y.end(), output.begin(),
|
||||
[X](T y) {
|
||||
return std::bit_and<T>()(X, y);
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
const T Y = per_iter_bh.ScalarInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), output.begin(),
|
||||
[Y](T x) {
|
||||
return static_cast<T>(std::bit_and<T>()(x, Y));
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_and<T>());
|
||||
}};
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs, 1.0f);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BitwiseNot<T>::Compute(OpKernelContext* context) const {
|
||||
auto& input = *context->Input<Tensor>(0);
|
||||
auto& output = *context->Output(0, input.Shape());
|
||||
|
||||
std::transform(EigenMap<T>(input).array().begin(), EigenMap<T>(input).array().end(), EigenMap<T>(output).array().begin(), std::bit_not<T>());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BitwiseOr<T>::Compute(OpKernelContext* context) const {
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
const T X = per_iter_bh.ScalarInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(Y.begin(), Y.end(), output.begin(),
|
||||
[X](T y) {
|
||||
return std::bit_or<T>()(X, y);
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
const T Y = per_iter_bh.ScalarInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), output.begin(),
|
||||
[Y](T x) {
|
||||
return static_cast<T>(std::bit_or<T>()(x, Y));
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_or<T>());
|
||||
}};
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs, 1.0f);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BitwiseXor<T>::Compute(OpKernelContext* context) const {
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
const T X = per_iter_bh.ScalarInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(Y.begin(), Y.end(), output.begin(),
|
||||
[X](T y) {
|
||||
return std::bit_xor<T>()(X, y);
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
const T Y = per_iter_bh.ScalarInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), output.begin(),
|
||||
[Y](T x) {
|
||||
return static_cast<T>(std::bit_xor<T>()(x, Y));
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_xor<T>());
|
||||
}};
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs, 1.0f);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Sin final : public OpKernel {
|
||||
public:
|
||||
Sin(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
|
|
|||
|
|
@ -427,6 +427,42 @@ class BitShift final : public OpKernel {
|
|||
bool shift_left_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BitwiseAnd final : public OpKernel {
|
||||
public:
|
||||
explicit BitwiseAnd(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BitwiseNot final : public OpKernel {
|
||||
public:
|
||||
explicit BitwiseNot(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BitwiseOr final : public OpKernel {
|
||||
public:
|
||||
explicit BitwiseOr(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BitwiseXor final : public OpKernel {
|
||||
public:
|
||||
explicit BitwiseXor(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
// PRelu is activation function, but it's closer to binary elementwise ops in implementation
|
||||
template <typename T>
|
||||
class PRelu final : public OpKernel {
|
||||
|
|
|
|||
|
|
@ -116,7 +116,6 @@
|
|||
"^test_div_uint8_cuda",
|
||||
"^test_add_uint8_cuda",
|
||||
"^test_roialign_aligned_*",
|
||||
"^test_bitwise_*",
|
||||
"^test_clip_default_int8_max_expanded_cpu",
|
||||
"^test_clip_default_int8_min_expanded_cpu",
|
||||
"^test_col2im_*",
|
||||
|
|
|
|||
Loading…
Reference in a new issue