From 6f95cdfa68d5d950f342f4f8ee83930cfbff8d67 Mon Sep 17 00:00:00 2001 From: "M. Zeeshan Siddiqui" Date: Mon, 4 May 2020 15:23:10 -0700 Subject: [PATCH] Use new cost based threadpool abstractions in CPU gradient operators. (#3807) * Use ThreadPool abstractions instead of OpenMP. * PR feedback. --- .../training_ops/cpu/gist/gistdecode_op.cc | 5 --- .../training_ops/cpu/gist/gistencode_op.cc | 5 --- .../training_ops/cpu/tensor/gather_grad.cc | 44 +++++++++++-------- .../training_ops/cpu/tensor/gather_grad.h | 3 +- 4 files changed, 27 insertions(+), 30 deletions(-) diff --git a/orttraining/orttraining/training_ops/cpu/gist/gistdecode_op.cc b/orttraining/orttraining/training_ops/cpu/gist/gistdecode_op.cc index 7d1810b337..180497d162 100644 --- a/orttraining/orttraining/training_ops/cpu/gist/gistdecode_op.cc +++ b/orttraining/orttraining/training_ops/cpu/gist/gistdecode_op.cc @@ -17,14 +17,9 @@ Status GistBinarizeDecoderOp::Compute(OpKernelContext* context) const { const auto* X = context->Input(1); ORT_ENFORCE(X != nullptr); const TensorShape& shape = X->Shape(); - Tensor* Y = context->Output(0, shape); - const auto* src = X->template Data(); auto* dst = Y->template MutableData(); -#ifdef USE_OPENMP -#pragma omp parallel for -#endif for (int64_t i = 0; i < X->Shape().Size(); ++i) { dst[i] = src[i] ? 
1.0f : 0.0f; } diff --git a/orttraining/orttraining/training_ops/cpu/gist/gistencode_op.cc b/orttraining/orttraining/training_ops/cpu/gist/gistencode_op.cc index 9d28a484d8..cb81fa4ca5 100644 --- a/orttraining/orttraining/training_ops/cpu/gist/gistencode_op.cc +++ b/orttraining/orttraining/training_ops/cpu/gist/gistencode_op.cc @@ -19,14 +19,9 @@ Status GistBinarizeEncoderOp::Compute(OpKernelContext* context) const { const TensorShape& shape = X->Shape(); Tensor* Y = context->Output(0, shape); Tensor* Y1 = context->Output(1, shape); - auto X_type = X->DataType(); - auto* src = X->template Data(); auto* dst = Y1->template MutableData(); -#ifdef USE_OPENMP -#pragma omp parallel for -#endif for (int64_t i = 0; i < X->Shape().Size(); ++i) { dst[i] = src[i] > 0.0; } diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc b/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc index 9e9b9ba99e..5f37bf35ce 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc @@ -3,6 +3,7 @@ #include "orttraining/training_ops/cpu/tensor/gather_grad.h" #include "core/common/common.h" +#include "core/platform/threadpool.h" namespace onnxruntime { namespace contrib { @@ -20,14 +21,14 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType()}), GatherGrad); -#define TYPED_GRAD_FUNCTION_CALL(T) \ - if (T_type == DataTypeImpl::GetType()) { \ - if (Tind_type == DataTypeImpl::GetType()) { \ - return ComputeImpl(data_shape, indices, grad, output); \ - } \ - if (Tind_type == DataTypeImpl::GetType()) { \ - return ComputeImpl(data_shape, indices, grad, output); \ - } \ +#define TYPED_GRAD_FUNCTION_CALL(T, tp) \ + if (T_type == DataTypeImpl::GetType()) { \ + if (Tind_type == DataTypeImpl::GetType()) { \ + return ComputeImpl(data_shape, indices, grad, output, tp); \ + } \ + if (Tind_type == DataTypeImpl::GetType()) { \ + return ComputeImpl(data_shape, indices, grad, output, tp); \ + } 
\ } Status GatherGrad::Compute(OpKernelContext* context) const { @@ -41,14 +42,15 @@ Status GatherGrad::Compute(OpKernelContext* context) const { MLDataType T_type = grad.DataType(); MLDataType Tind_type = indices.DataType(); - TYPED_GRAD_FUNCTION_CALL(float); - TYPED_GRAD_FUNCTION_CALL(double); + TYPED_GRAD_FUNCTION_CALL(float, context->GetOperatorThreadPool()); + TYPED_GRAD_FUNCTION_CALL(double, context->GetOperatorThreadPool()); return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in GatherGrad."); } template -Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const { +Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output, + concurrency::ThreadPool* tp) const { const Tind* indices_data = indices.template Data(); const T* grad_data = grad.template Data(); T* output_data = output.template MutableData(); @@ -71,21 +73,25 @@ Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indi } } -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t g = 0; g < grad_size; g++) { + std::mutex mtx; + auto lambda = [&](int64_t g) { const int64_t input_block_index = g / output_block_size; const int64_t block_offset = g % output_block_size; const int64_t indices_index = block_offset / block_size; const int64_t offset = block_offset % block_size; const Tind idx = indices_data[indices_index]; const int64_t input_index = input_block_index * input_block_size + idx * block_size + offset; -#ifdef USE_OPENMP -#pragma omp atomic -#endif + //REVIEW(codemzs): This lock can become a performance bottleneck. An area for potential improvement. 
+ std::lock_guard lck(mtx); output_data[input_index] += grad_data[g]; - } + }; + + concurrency::ThreadPool::TryParallelFor(tp, grad_size, static_cast(block_size), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t index = first; index < last; ++index) { /* iterate in ptrdiff_t: narrowing first/last to int would wrap for ranges beyond INT_MAX */ + lambda(index); + } + }); return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.h b/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.h index 0898949a11..624326d9dd 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.h +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_grad.h @@ -20,7 +20,8 @@ class GatherGrad final : public OpKernel { private: template - Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const; + Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output, + concurrency::ThreadPool* tp) const; int64_t axis_; };