From 2b8d9ef0fde81ea362d8b15577b4d839e804b228 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Sat, 2 May 2020 00:09:31 -0700 Subject: [PATCH] Refactor scatter/gather ops to use the new cost based threadpool abstractions. (#3776) * Update Scatter and Gather ops by replacing pragma omp invocations with the new threadpool abstractions. * Use forward declarations * PR comments --- .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 2 +- .../core/providers/cpu/tensor/gather.cc | 23 +++-- .../core/providers/cpu/tensor/gather_nd.cc | 61 +++++++------ .../core/providers/cpu/tensor/gather_nd.h | 11 +-- .../core/providers/cpu/tensor/scatter_nd.cc | 88 +++++++++---------- .../core/providers/cpu/tensor/scatter_nd.h | 56 ++++++------ 6 files changed, 128 insertions(+), 113 deletions(-) diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index f8a8b7dac8..9d4e7c2ec2 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -175,7 +175,7 @@ static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, do } #else const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0); - concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [lambda, step](ptrdiff_t first, ptrdiff_t last) { + concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [&lambda, step](ptrdiff_t first, ptrdiff_t last) { for (int i = static_cast(first), end = static_cast(last); i < end; ++i) { lambda(i * step); } diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc index 968a821af2..ae7fc852f0 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather.cc @@ -4,6 +4,7 @@ //https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather #include "core/providers/cpu/tensor/gather.h" #include "core/common/common.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -57,11 +58,10 @@ template Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uint8_t* dst_base, bool is_string_type, const size_t element_bytes, const int64_t block_size, const int64_t M, const int64_t N, const int64_t data_batch_bytes, const int64_t gathered_batch_bytes, - const TensorShape& input_data_shape, const int64_t axis) { + const TensorShape& input_data_shape, const int64_t axis, concurrency::ThreadPool* tp) { const Tin* indices_data = indices_tensor->template Data(); // Check the indices first in case there's a out of bound index. - // We can't merge this code in the omp loop below as omp does not allow return in the loop auto axis_dim_limit = input_data_shape[axis]; for (int64_t i = 0; i < N; ++i) { @@ -73,10 +73,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin } } -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t index = 0; index < M * N; ++index) { + auto lambda = [&](int64_t index) { int64_t batch = index / N; int64_t i = index % N; @@ -93,7 +90,13 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin } else { memcpy(dst_base + dst_offset, src_base + src_offset, block_size); } - } + }; + concurrency::ThreadPool::TryParallelFor(tp, M * N, static_cast(block_size), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int index = static_cast(first), end = static_cast(last); index < end; ++index) { + lambda(index); + } + }); return Status::OK(); } @@ -117,13 +120,15 @@ Status Gather::Compute(OpKernelContext* context) const { const auto* src_base = static_cast(p.input_tensor->DataRaw()); auto* dst_base = static_cast(p.output_tensor->MutableDataRaw()); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + if (p.indices_tensor->IsDataType()) { return GatherCopyData(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes, - block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis); + block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp); } if (p.indices_tensor->IsDataType()) { return GatherCopyData(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes, - block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis); + block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp); } return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather."); diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index e92f456c86..94b8f22539 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gather_nd.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -43,7 +44,7 @@ ONNX_CPU_OPERATOR_KERNEL( GatherND); template -Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { +Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const { const auto* input_tensor = context->Input(0); const auto* indices_tensor = context->Input(1); ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch"); @@ -72,9 +73,6 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); std::vector sizes_from_slice_dims(num_slice_dims); -#ifdef _OPENMP -#pragma omp parallel for -#endif for (int64_t i = 0; i < num_slice_dims; ++i) { sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1); } @@ -95,10 +93,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con } // Compute the element_offset -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) { + auto lambda = [&](int64_t slice_idx) { const size_t batch_idx = slice_idx / num_slices_per_batch; const size_t input_base_offset = batch_idx * input_batch_stride; @@ -118,46 +113,58 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con } p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset; - } + }; + concurrency::ThreadPool::TryParallelFor(tp, num_slices, static_cast(num_slice_dims), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int slice_idx = static_cast(first), end = static_cast(last); slice_idx < end; ++slice_idx) { + lambda(slice_idx); + } + }); return err_index == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index); } -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const; +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const; Status GatherND::Compute(OpKernelContext* context) const { Prepare p; + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); ORT_RETURN_IF_ERROR(context->Input(1)->IsDataType() - ? PrepareForCompute(context, p) - : PrepareForCompute(context, p)); + ? PrepareForCompute(context, p, tp) + : PrepareForCompute(context, p, tp)); - return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); + return nullptr == p.input_str_base ? GatherNumber(p, tp) : GatherString(p, tp); } -Status GatherND::GatherNumber(const Prepare& p) const { -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t slice_idx = 0; slice_idx < static_cast(p.slice_offsets.size()); ++slice_idx) { +Status GatherND::GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const { + auto lambda = [&](int64_t slice_idx) { memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes, p.bytes_per_slice); - } - + }; + concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast(p.bytes_per_slice), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int slice_idx = static_cast(first), end = static_cast(last); slice_idx < end; ++slice_idx) { + lambda(slice_idx); + } + }); return Status::OK(); } -Status GatherND::GatherString(const Prepare& p) const { -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t slice_idx = 0; slice_idx < static_cast(p.slice_offsets.size()); ++slice_idx) { +Status GatherND::GatherString(const Prepare& p, concurrency::ThreadPool* tp) const { + auto lambda = [&](int64_t slice_idx) { const int64_t slice_base_offset = slice_idx * p.element_count_per_slice; for (int64_t j = 0; j < static_cast(p.element_count_per_slice); ++j) { p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j]; } - } + }; + concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast(p.element_count_per_slice), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int slice_idx = static_cast(first), end = static_cast(last); slice_idx < end; ++slice_idx) { + lambda(slice_idx); + } + }); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.h b/onnxruntime/core/providers/cpu/tensor/gather_nd.h index d3765751bc..26c3088b4a 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.h @@ -5,10 +5,11 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/platform/threadpool.h" namespace onnxruntime { - +namespace concurrency { +class ThreadPool; +} class GatherNDBase { protected: struct Prepare { @@ -32,7 +33,7 @@ class GatherNDBase { }; // struct Prepare template - Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; + Status PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const; int64_t batch_dims_; }; // class GatherNDBase @@ -44,8 +45,8 @@ class GatherND final : public OpKernel, protected GatherNDBase { Status Compute(OpKernelContext* context) const override; private: - Status GatherNumber(const Prepare& p) const; - Status GatherString(const Prepare& p) const; + Status GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const; + Status GatherString(const Prepare& p, concurrency::ThreadPool* tp) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc index 7b626ce0ef..9b1b173c43 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "scatter_nd.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -9,37 +10,36 @@ ONNX_CPU_OPERATOR_KERNEL( ScatterND, 11, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), ScatterND); -template +template Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { - - auto input_tensor = context->Input(0); + auto input_tensor = context->Input(0); auto indice_tensor = context->Input(1); auto update_tensor = context->Input(2); - ORT_ENFORCE(input_tensor != nullptr); + ORT_ENFORCE(input_tensor != nullptr); ORT_ENFORCE(indice_tensor != nullptr); ORT_ENFORCE(update_tensor != nullptr); - auto input_shape = input_tensor->Shape(); - auto indice_shape = indice_tensor->Shape(); - auto update_shape = update_tensor->Shape(); + auto input_shape = input_tensor->Shape(); + auto indice_shape = indice_tensor->Shape(); + auto update_shape = update_tensor->Shape(); if (indice_shape.NumDimensions() == 0 || input_shape.NumDimensions() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input tensor and indices tensor must has rank larger than 0. ", - "input shape: ", input_shape, ", indices shape: ", indice_shape); + "input tensor and indices tensor must has rank larger than 0. ", + "input shape: ", input_shape, ", indices shape: ", indice_shape); } auto indice_rank = indice_shape.NumDimensions(); auto last_indice_dimension = indice_shape[indice_rank - 1]; if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "last dimension of indices must not be larger than rank of input tensor"); + "last dimension of indices must not be larger than rank of input tensor"); } - bool is_update_shape_invalid = [&](){ + bool is_update_shape_invalid = [&]() { auto update_rank = update_shape.NumDimensions(); auto input_rank = input_shape.NumDimensions(); if (update_rank < indice_rank - 1) { @@ -59,8 +59,8 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co }(); if (is_update_shape_invalid) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ", - "updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape); + "updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ", + "updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape); } auto output_tensor = context->Output(0, TensorShape(input_shape)); @@ -81,34 +81,28 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co } } - std::vector element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension + std::vector element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension -#ifdef _OPENMP -#pragma omp parallel for -#endif for (int64_t i = 0; i < last_indice_dimension; ++i) { element_counts[i] = input_shape.SizeFromDimension(i + 1); } int64_t err_indice = 0; - p.element_bytes = input_tensor->DataType()->Size(); - p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); - p.bytes_to_copy = p.element_bytes * p.element_to_copy; + p.element_bytes = input_tensor->DataType()->Size(); + p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); + p.bytes_to_copy = p.element_bytes * p.element_to_copy; auto indice_offset = static_cast(indice_tensor->DataRaw()); - auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy + auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy p.element_offsets.assign(offset_count, 0LL); if (input_tensor->IsDataTypeString()) { - p.input_str_base = static_cast(update_tensor->DataRaw()); + p.input_str_base = static_cast(update_tensor->DataRaw()); p.output_str_base = static_cast(output_tensor->MutableDataRaw()); } else { - p.input_base = static_cast(update_tensor->DataRaw()); - p.output_base = static_cast(output_tensor->MutableDataRaw()); + p.input_base = static_cast(update_tensor->DataRaw()); + p.output_base = static_cast(output_tensor->MutableDataRaw()); } -#ifdef _OPENMP -#pragma omp parallel for -#endif for (int64_t i = 0; i < offset_count; ++i) { for (int64_t j = 0; j < last_indice_dimension; ++j) { auto indice = *(indice_offset + i * last_indice_dimension + j); @@ -118,40 +112,46 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co p.element_offsets[i] += indice * element_counts[j]; } } - return err_indice == 0 ? Status::OK() : - ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); + return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); } template Status ScatterNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; Status ScatterND::Compute(OpKernelContext* context) const { Prepare p; + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); - return nullptr == p.input_str_base ? ScatterNumber(p) : ScatterString(p); + return nullptr == p.input_str_base ? ScatterNumber(p, tp) : ScatterString(p, tp); } -Status ScatterND::ScatterNumber(const Prepare& p) const { -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { +Status ScatterND::ScatterNumber(const Prepare& p, concurrency::ThreadPool* tp) const { + auto lambda = [&](int64_t i) { memcpy(p.output_base + p.element_offsets[i] * p.element_bytes, p.input_base + i * p.bytes_to_copy, p.bytes_to_copy); - } + }; + concurrency::ThreadPool::TryParallelFor(tp, p.element_offsets.size(), static_cast(p.bytes_to_copy), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int i = static_cast(first), end = static_cast(last); i < end; ++i) { + lambda(i); + } + }); return Status::OK(); } -Status ScatterND::ScatterString(const Prepare& p) const { -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { +Status ScatterND::ScatterString(const Prepare& p, concurrency::ThreadPool* tp) const { + auto lambda = [&](int64_t i) { for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { p.output_str_base[p.element_offsets[i] + j] = p.input_str_base[i * p.element_to_copy + j]; } - } + }; + concurrency::ThreadPool::TryParallelFor(tp, p.element_offsets.size(), static_cast(p.element_to_copy), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (int i = static_cast(first), end = static_cast(last); i < end; ++i) { + lambda(i); + } + }); return Status::OK(); } -} \ No newline at end of file +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h index 3c91441bbe..9d9a52ea41 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h @@ -7,41 +7,43 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { - -class ScatterNDBase -{ -protected: +namespace concurrency { +class ThreadPool; +} +class ScatterNDBase { + protected: struct Prepare { - const uint8_t* input_base; - const std::string* input_str_base; - uint8_t* output_base; - std::string* output_str_base; - uint64_t bytes_to_copy; - uint64_t element_bytes; - uint64_t element_to_copy; + const uint8_t* input_base; + const std::string* input_str_base; + uint8_t* output_base; + std::string* output_str_base; + uint64_t bytes_to_copy; + uint64_t element_bytes; + uint64_t element_to_copy; std::vector element_offsets; - Prepare(): input_base (nullptr), - input_str_base (nullptr), - output_base (nullptr), - output_str_base (nullptr), - bytes_to_copy (0), - element_bytes (0), - element_to_copy (0), - element_offsets (0) {} - }; // struct Prepare + Prepare() : input_base(nullptr), + input_str_base(nullptr), + output_base(nullptr), + output_str_base(nullptr), + bytes_to_copy(0), + element_bytes(0), + element_to_copy(0), + element_offsets(0) {} + }; // struct Prepare - template + template Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; -}; // class ScatterNDBase +}; // class ScatterNDBase class ScatterND final : public OpKernel, protected ScatterNDBase { -public: + public: explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* context) const override; -private: - Status ScatterNumber(const Prepare& p) const; - Status ScatterString(const Prepare& p) const; + + private: + Status ScatterNumber(const Prepare& p, concurrency::ThreadPool* tp) const; + Status ScatterString(const Prepare& p, concurrency::ThreadPool* tp) const; }; -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime \ No newline at end of file