mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Parallel Gelu with ParallelFor (#2399)
Parallel Gelu to get better performance for Gelu
This commit is contained in:
parent
ca0ed96621
commit
d49cbf6e08
1 changed files with 26 additions and 0 deletions
|
|
@ -7,6 +7,7 @@
|
|||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -38,6 +39,31 @@ class Gelu : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
Tensor* Y = context->Output(0, X->Shape());
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
if (nullptr != tp) {
|
||||
const T* input = X->template Data<T>();
|
||||
T* output = Y->template MutableData<T>();
|
||||
int task_count = tp->NumThreads() + 1;
|
||||
int64_t elem_count = X->Shape().Size();
|
||||
if (elem_count > task_count) {
|
||||
tp->ParallelFor(task_count, [input,
|
||||
output,
|
||||
elem_count,
|
||||
task_count](int32_t i) {
|
||||
int64_t elem_inx_start = i * elem_count / task_count;
|
||||
int64_t elem_inx_end = (i + 1) * elem_count / task_count;
|
||||
for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) {
|
||||
output[elem_inx] = input[elem_inx] * static_cast<float>(M_SQRT1_2);
|
||||
}
|
||||
MlasComputeErf(output + elem_inx_start, output + elem_inx_start, elem_inx_end - elem_inx_start);
|
||||
for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) {
|
||||
output[elem_inx] = 0.5f * input[elem_inx] * (output[elem_inx] + 1.0f);
|
||||
}
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_X_VAR(xm);
|
||||
EIGEN_Y_VAR(ym);
|
||||
ym = xm * static_cast<float>(M_SQRT1_2);
|
||||
|
|
|
|||
Loading…
Reference in a new issue