Use new cost-based threadpool abstractions in CPU gradient operators. (#3807)

* Use ThreadPool abstractions instead of OpenMP.

* PR feedback.
This commit is contained in:
M. Zeeshan Siddiqui 2020-05-04 15:23:10 -07:00 committed by GitHub
parent 156368b67f
commit 6f95cdfa68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 30 deletions

View file

@ -17,14 +17,9 @@ Status GistBinarizeDecoderOp::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(1);
ORT_ENFORCE(X != nullptr);
const TensorShape& shape = X->Shape();
Tensor* Y = context->Output(0, shape);
const auto* src = X->template Data<bool>();
auto* dst = Y->template MutableData<float>();
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < X->Shape().Size(); ++i) {
dst[i] = src[i] ? 1.0f : 0.0f;
}

View file

@ -19,14 +19,9 @@ Status GistBinarizeEncoderOp::Compute(OpKernelContext* context) const {
const TensorShape& shape = X->Shape();
Tensor* Y = context->Output(0, shape);
Tensor* Y1 = context->Output(1, shape);
auto X_type = X->DataType();
auto* src = X->template Data<float>();
auto* dst = Y1->template MutableData<bool>();
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < X->Shape().Size(); ++i) {
dst[i] = src[i] > 0.0;
}

View file

@ -3,6 +3,7 @@
#include "orttraining/training_ops/cpu/tensor/gather_grad.h"
#include "core/common/common.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
namespace contrib {
@ -20,14 +21,14 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<int64_t>()}),
GatherGrad);
// TYPED_GRAD_FUNCTION_CALL: type-dispatch helper expanded inside GatherGrad::Compute.
// When T_type matches T, it returns (from the enclosing function) the Status of
// ComputeImpl<T, Tind> for Tind = int32_t or int64_t, whichever matches Tind_type.
// If nothing matches, the expansion falls through without returning, so the caller
// can try the next type or report an error.
// The expansion site must have T_type, Tind_type, data_shape, indices, grad and output in scope.
// This diff shows the previous one-argument form first, followed by its replacement,
// which additionally forwards the `tp` argument to ComputeImpl.
#define TYPED_GRAD_FUNCTION_CALL(T) \
if (T_type == DataTypeImpl::GetType<T>()) { \
if (Tind_type == DataTypeImpl::GetType<int32_t>()) { \
return ComputeImpl<T, int32_t>(data_shape, indices, grad, output); \
} \
if (Tind_type == DataTypeImpl::GetType<int64_t>()) { \
return ComputeImpl<T, int64_t>(data_shape, indices, grad, output); \
} \
#define TYPED_GRAD_FUNCTION_CALL(T, tp) \
if (T_type == DataTypeImpl::GetType<T>()) { \
if (Tind_type == DataTypeImpl::GetType<int32_t>()) { \
return ComputeImpl<T, int32_t>(data_shape, indices, grad, output, tp); \
} \
if (Tind_type == DataTypeImpl::GetType<int64_t>()) { \
return ComputeImpl<T, int64_t>(data_shape, indices, grad, output, tp); \
} \
}
Status GatherGrad::Compute(OpKernelContext* context) const {
@ -41,14 +42,15 @@ Status GatherGrad::Compute(OpKernelContext* context) const {
MLDataType T_type = grad.DataType();
MLDataType Tind_type = indices.DataType();
TYPED_GRAD_FUNCTION_CALL(float);
TYPED_GRAD_FUNCTION_CALL(double);
TYPED_GRAD_FUNCTION_CALL(float, context->GetOperatorThreadPool());
TYPED_GRAD_FUNCTION_CALL(double, context->GetOperatorThreadPool());
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in GatherGrad.");
}
template <typename T, typename Tind>
Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const {
Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output,
concurrency::ThreadPool* tp) const {
const Tind* indices_data = indices.template Data<Tind>();
const T* grad_data = grad.template Data<T>();
T* output_data = output.template MutableData<T>();
@ -71,21 +73,25 @@ Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indi
}
}
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int64_t g = 0; g < grad_size; g++) {
// One mutex shared by every invocation of the lambda below; it serializes all
// accumulations into output_data.
std::mutex mtx;
// Per-element worker: adds grad_data[g] into the output_data slot that flat gradient
// position g maps to. Captures everything by reference.
auto lambda = [&](int64_t g) {
// Split g into a block number and a position inside that block.
const int64_t input_block_index = g / output_block_size;
const int64_t block_offset = g % output_block_size;
// Within the block, pick which gathered index applies and the element offset under it.
const int64_t indices_index = block_offset / block_size;
const int64_t offset = block_offset % block_size;
const Tind idx = indices_data[indices_index];
// Destination element in output_data for this gradient value.
const int64_t input_index = input_block_index * input_block_size + idx * block_size + offset;
#ifdef USE_OPENMP
#pragma omp atomic
#endif
//REVIEW(codemzs): This lock can become a performance bottleneck. An area for potential improvement.
// Two different g values yield the same input_index whenever their idx values coincide
// (presumably when indices contains duplicates; confirm), so the += must not race.
std::lock_guard<std::mutex> lck(mtx);
output_data[input_index] += grad_data[g];
}
};
// Apply lambda to every flat gradient index in [0, grad_size). The thread pool splits the
// range into [first, last) chunks; block_size is passed as the per-element cost estimate.
// The chunk bounds are kept as ptrdiff_t: narrowing them to int would wrap once grad_size
// exceeds INT_MAX and make the loop visit wrong (or negative) indices.
concurrency::ThreadPool::TryParallelFor(tp, grad_size, static_cast<double>(block_size),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (ptrdiff_t index = first; index < last; ++index) {
lambda(index);
}
});
return Status::OK();
}

View file

@ -20,7 +20,8 @@ class GatherGrad final : public OpKernel {
private:
template <typename T, typename Tind>
Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const;
Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output,
concurrency::ThreadPool* tp) const;
int64_t axis_;
};