mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Use new cost-based threadpool abstractions in CPU gradient operators. (#3807)
* Use ThreadPool abstractions instead of OpenMP.
* PR feedback.
This commit is contained in:
parent
156368b67f
commit
6f95cdfa68
4 changed files with 27 additions and 30 deletions
|
|
@ -17,14 +17,9 @@ Status GistBinarizeDecoderOp::Compute(OpKernelContext* context) const {
|
|||
const auto* X = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(X != nullptr);
|
||||
const TensorShape& shape = X->Shape();
|
||||
|
||||
Tensor* Y = context->Output(0, shape);
|
||||
|
||||
const auto* src = X->template Data<bool>();
|
||||
auto* dst = Y->template MutableData<float>();
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < X->Shape().Size(); ++i) {
|
||||
dst[i] = src[i] ? 1.0f : 0.0f;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,14 +19,9 @@ Status GistBinarizeEncoderOp::Compute(OpKernelContext* context) const {
|
|||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, shape);
|
||||
Tensor* Y1 = context->Output(1, shape);
|
||||
|
||||
auto X_type = X->DataType();
|
||||
|
||||
auto* src = X->template Data<float>();
|
||||
auto* dst = Y1->template MutableData<bool>();
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < X->Shape().Size(); ++i) {
|
||||
dst[i] = src[i] > 0.0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "orttraining/training_ops/cpu/tensor/gather_grad.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -20,14 +21,14 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
GatherGrad);
|
||||
|
||||
#define TYPED_GRAD_FUNCTION_CALL(T) \
|
||||
if (T_type == DataTypeImpl::GetType<T>()) { \
|
||||
if (Tind_type == DataTypeImpl::GetType<int32_t>()) { \
|
||||
return ComputeImpl<T, int32_t>(data_shape, indices, grad, output); \
|
||||
} \
|
||||
if (Tind_type == DataTypeImpl::GetType<int64_t>()) { \
|
||||
return ComputeImpl<T, int64_t>(data_shape, indices, grad, output); \
|
||||
} \
|
||||
#define TYPED_GRAD_FUNCTION_CALL(T, tp) \
|
||||
if (T_type == DataTypeImpl::GetType<T>()) { \
|
||||
if (Tind_type == DataTypeImpl::GetType<int32_t>()) { \
|
||||
return ComputeImpl<T, int32_t>(data_shape, indices, grad, output, tp); \
|
||||
} \
|
||||
if (Tind_type == DataTypeImpl::GetType<int64_t>()) { \
|
||||
return ComputeImpl<T, int64_t>(data_shape, indices, grad, output, tp); \
|
||||
} \
|
||||
}
|
||||
|
||||
Status GatherGrad::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -41,14 +42,15 @@ Status GatherGrad::Compute(OpKernelContext* context) const {
|
|||
|
||||
MLDataType T_type = grad.DataType();
|
||||
MLDataType Tind_type = indices.DataType();
|
||||
TYPED_GRAD_FUNCTION_CALL(float);
|
||||
TYPED_GRAD_FUNCTION_CALL(double);
|
||||
TYPED_GRAD_FUNCTION_CALL(float, context->GetOperatorThreadPool());
|
||||
TYPED_GRAD_FUNCTION_CALL(double, context->GetOperatorThreadPool());
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in GatherGrad.");
|
||||
}
|
||||
|
||||
template <typename T, typename Tind>
|
||||
Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const {
|
||||
Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output,
|
||||
concurrency::ThreadPool* tp) const {
|
||||
const Tind* indices_data = indices.template Data<Tind>();
|
||||
const T* grad_data = grad.template Data<T>();
|
||||
T* output_data = output.template MutableData<T>();
|
||||
|
|
@ -71,21 +73,25 @@ Status GatherGrad::ComputeImpl(const TensorShape& data_shape, const Tensor& indi
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t g = 0; g < grad_size; g++) {
|
||||
std::mutex mtx;
|
||||
auto lambda = [&](int64_t g) {
|
||||
const int64_t input_block_index = g / output_block_size;
|
||||
const int64_t block_offset = g % output_block_size;
|
||||
const int64_t indices_index = block_offset / block_size;
|
||||
const int64_t offset = block_offset % block_size;
|
||||
const Tind idx = indices_data[indices_index];
|
||||
const int64_t input_index = input_block_index * input_block_size + idx * block_size + offset;
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp atomic
|
||||
#endif
|
||||
//REVIEW(codemzs): This lock can become a performance bottleneck. An area for potential improvement.
|
||||
std::lock_guard<std::mutex> lck(mtx);
|
||||
output_data[input_index] += grad_data[g];
|
||||
}
|
||||
};
|
||||
|
||||
concurrency::ThreadPool::TryParallelFor(tp, grad_size, static_cast<double>(block_size),
|
||||
[&lambda](ptrdiff_t first, ptrdiff_t last) {
|
||||
for (int index = static_cast<int>(first), end = static_cast<int>(last); index < end; ++index) {
|
||||
lambda(index);
|
||||
}
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ class GatherGrad final : public OpKernel {
|
|||
|
||||
private:
|
||||
template <typename T, typename Tind>
|
||||
Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output) const;
|
||||
Status ComputeImpl(const TensorShape& data_shape, const Tensor& indices, const Tensor& grad, Tensor& output,
|
||||
concurrency::ThreadPool* tp) const;
|
||||
|
||||
int64_t axis_;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue