Parallel Gelu with ParallelFor (#2399)

Parallelize the Gelu computation across the operator thread pool (via ParallelFor) to improve its performance.
This commit is contained in:
Yufeng Li 2019-11-22 11:48:46 -08:00 committed by GitHub
parent ca0ed96621
commit d49cbf6e08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,6 +7,7 @@
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
#include <unsupported/Eigen/SpecialFunctions>
namespace onnxruntime {
@ -38,6 +39,31 @@ class Gelu : public OpKernel {
Status Compute(OpKernelContext* context) const override {
const auto* X = context->Input<Tensor>(0);
Tensor* Y = context->Output(0, X->Shape());
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
if (nullptr != tp) {
const T* input = X->template Data<T>();
T* output = Y->template MutableData<T>();
int task_count = tp->NumThreads() + 1;
int64_t elem_count = X->Shape().Size();
if (elem_count > task_count) {
tp->ParallelFor(task_count, [input,
output,
elem_count,
task_count](int32_t i) {
int64_t elem_inx_start = i * elem_count / task_count;
int64_t elem_inx_end = (i + 1) * elem_count / task_count;
for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) {
output[elem_inx] = input[elem_inx] * static_cast<float>(M_SQRT1_2);
}
MlasComputeErf(output + elem_inx_start, output + elem_inx_start, elem_inx_end - elem_inx_start);
for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) {
output[elem_inx] = 0.5f * input[elem_inx] * (output[elem_inx] + 1.0f);
}
});
return Status::OK();
}
}
EIGEN_X_VAR(xm);
EIGEN_Y_VAR(ym);
ym = xm * static_cast<float>(M_SQRT1_2);