diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index efb258d03b..cc13805f9f 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -7,6 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/util/math_cpuonly.h" #include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" #include namespace onnxruntime { @@ -38,6 +39,31 @@ class Gelu : public OpKernel { Status Compute(OpKernelContext* context) const override { const auto* X = context->Input(0); Tensor* Y = context->Output(0, X->Shape()); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + if (nullptr != tp) { + const T* input = X->template Data(); + T* output = Y->template MutableData(); + int task_count = tp->NumThreads() + 1; + int64_t elem_count = X->Shape().Size(); + if (elem_count > task_count) { + tp->ParallelFor(task_count, [input, + output, + elem_count, + task_count](int32_t i) { + int64_t elem_inx_start = i * elem_count / task_count; + int64_t elem_inx_end = (i + 1) * elem_count / task_count; + for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) { + output[elem_inx] = input[elem_inx] * static_cast(M_SQRT1_2); + } + MlasComputeErf(output + elem_inx_start, output + elem_inx_start, elem_inx_end - elem_inx_start); + for (int64_t elem_inx = elem_inx_start; elem_inx < elem_inx_end; elem_inx++) { + output[elem_inx] = 0.5f * input[elem_inx] * (output[elem_inx] + 1.0f); + } + }); + return Status::OK(); + } + } + EIGEN_X_VAR(xm); EIGEN_Y_VAR(ym); ym = xm * static_cast(M_SQRT1_2);