mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Optimize use of Eigen::DenseBase::select() for PRelu (#15287)
MSVC and gcc are both not good at optimizing select(), even in trivial usage outside of ORT. gcc seems to do better with -ffast-math (not used by ORT) but /fp:fast does nothing for MSVC This PR delivers a 33% speedup on the same model (360us -> 270us on Windows; 205 us -> 153 us on Linux; measured on different systems). TODO: Examine and fix Elu and other similar activation functions for the use of `Eigen::select` Co-authored-by: @fpribeiro ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
fff75a301c
commit
c06ab5e353
1 changed files with 26 additions and 16 deletions
|
|
@ -370,7 +370,7 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 8, 12, float, Mean_8);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Mean, 13, float, Mean_8);
|
||||
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint8_t, BitShift);
|
||||
//REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint16_t, BitShift);
|
||||
// REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint16_t, BitShift);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint32_t, BitShift);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(BitShift, 11, uint64_t, BitShift);
|
||||
|
||||
|
|
@ -1192,16 +1192,16 @@ Status BitShift<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
template <typename T>
|
||||
Status BitwiseAnd<T>::Compute(OpKernelContext* context) const {
|
||||
ProcessBroadcastSpanFuncs funcs {
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
const T X = per_iter_bh.ScalarInput0<T>();
|
||||
auto Y = per_iter_bh.SpanInput1<T>();
|
||||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(Y.begin(), Y.end(), output.begin(),
|
||||
[X](T y) {
|
||||
return std::bit_and<T>()(X, y);
|
||||
});
|
||||
[X](T y) {
|
||||
return std::bit_and<T>()(X, y);
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
|
|
@ -1209,9 +1209,9 @@ Status BitwiseAnd<T>::Compute(OpKernelContext* context) const {
|
|||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
|
||||
std::transform(X.begin(), X.end(), output.begin(),
|
||||
[Y](T x) {
|
||||
return static_cast<T>(std::bit_and<T>()(x, Y));
|
||||
});
|
||||
[Y](T x) {
|
||||
return static_cast<T>(std::bit_and<T>()(x, Y));
|
||||
});
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto X = per_iter_bh.SpanInput0<T>();
|
||||
|
|
@ -1220,7 +1220,7 @@ Status BitwiseAnd<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
std::transform(X.begin(), X.end(), Y.begin(), output.begin(), std::bit_and<T>());
|
||||
}};
|
||||
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs, 1.0f);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -1306,7 +1306,7 @@ Status BitwiseXor<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
class Sin final : public OpKernel {
|
||||
class Sin final : public OpKernel {
|
||||
public:
|
||||
Sin(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
|
@ -1580,14 +1580,24 @@ Status PRelu<float>::Compute(OpKernelContext* context) const {
|
|||
per_iter_bh.OutputEigen<float>() = input0 * per_iter_bh.EigenInput1<float>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto input0 = per_iter_bh.EigenInput0<float>();
|
||||
float input1 = per_iter_bh.ScalarInput1<float>();
|
||||
per_iter_bh.OutputEigen<float>() = (input0.array() > 0).select(input0, input0 * input1);
|
||||
const float* input0 = per_iter_bh.EigenInput0<float>().data();
|
||||
const float input1 = per_iter_bh.ScalarInput1<float>();
|
||||
float* output = per_iter_bh.OutputEigen<float>().data();
|
||||
size_t size = per_iter_bh.OutputEigen<float>().size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] = static_cast<float>(input0[i] > 0) * input0[i] +
|
||||
(1.0f - static_cast<float>(input0[i] > 0)) * input0[i] * input1;
|
||||
}
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
auto input0 = per_iter_bh.EigenInput0<float>();
|
||||
auto input1 = per_iter_bh.EigenInput1<float>();
|
||||
per_iter_bh.OutputEigen<float>() = (input0.array() > 0).select(input0, input0.cwiseProduct(input1));
|
||||
const float* input0 = per_iter_bh.EigenInput0<float>().data();
|
||||
const float* input1 = per_iter_bh.EigenInput1<float>().data();
|
||||
float* output = per_iter_bh.OutputEigen<float>().data();
|
||||
size_t size = per_iter_bh.OutputEigen<float>().size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] = static_cast<float>(input0[i] > 0) * input0[i] +
|
||||
(1.0f - static_cast<float>(input0[i] > 0)) * input0[i] * input1[i];
|
||||
}
|
||||
}};
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs, 1.0);
|
||||
|
|
|
|||
Loading…
Reference in a new issue