Refactor scatter/gather ops to use the new cost-based threadpool abstractions. (#3776)

* Update Scatter and Gather ops by replacing pragma omp invocations with the new threadpool abstractions.

* Use forward declarations

* Address PR review comments
This commit is contained in:
Pranav Sharma 2020-05-02 00:09:31 -07:00 committed by GitHub
parent 4f9f6aedea
commit 2b8d9ef0fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 113 deletions

View file

@ -175,7 +175,7 @@ static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, do
}
#else
const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [lambda, step](ptrdiff_t first, ptrdiff_t last) {
concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [&lambda, step](ptrdiff_t first, ptrdiff_t last) {
for (int i = static_cast<int>(first), end = static_cast<int>(last); i < end; ++i) {
lambda(i * step);
}

View file

@ -4,6 +4,7 @@
//https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather
#include "core/providers/cpu/tensor/gather.h"
#include "core/common/common.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
@ -57,11 +58,10 @@ template <typename Tin>
Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uint8_t* dst_base, bool is_string_type,
const size_t element_bytes, const int64_t block_size, const int64_t M,
const int64_t N, const int64_t data_batch_bytes, const int64_t gathered_batch_bytes,
const TensorShape& input_data_shape, const int64_t axis) {
const TensorShape& input_data_shape, const int64_t axis, concurrency::ThreadPool* tp) {
const Tin* indices_data = indices_tensor->template Data<Tin>();
// Check the indices first in case there's an out-of-bounds index.
// We can't merge this code in the omp loop below as omp does not allow return in the loop
auto axis_dim_limit = input_data_shape[axis];
for (int64_t i = 0; i < N; ++i) {
@ -73,10 +73,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
}
}
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t index = 0; index < M * N; ++index) {
auto lambda = [&](int64_t index) {
int64_t batch = index / N;
int64_t i = index % N;
@ -93,7 +90,13 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
} else {
memcpy(dst_base + dst_offset, src_base + src_offset, block_size);
}
}
};
concurrency::ThreadPool::TryParallelFor(tp, M * N, static_cast<double>(block_size),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int index = static_cast<int>(first), end = static_cast<int>(last); index < end; ++index) {
lambda(index);
}
});
return Status::OK();
}
@ -117,13 +120,15 @@ Status Gather::Compute(OpKernelContext* context) const {
const auto* src_base = static_cast<const uint8_t*>(p.input_tensor->DataRaw());
auto* dst_base = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
if (p.indices_tensor->IsDataType<int32_t>()) {
return GatherCopyData<int32_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis);
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
}
if (p.indices_tensor->IsDataType<int64_t>()) {
return GatherCopyData<int64_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis);
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "gather_nd.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
@ -43,7 +44,7 @@ ONNX_CPU_OPERATOR_KERNEL(
GatherND);
template <typename Tind>
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch");
@ -72,9 +73,6 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < num_slice_dims; ++i) {
sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1);
}
@ -95,10 +93,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
}
// Compute the element_offset
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
auto lambda = [&](int64_t slice_idx) {
const size_t batch_idx = slice_idx / num_slices_per_batch;
const size_t input_base_offset = batch_idx * input_batch_stride;
@ -118,46 +113,58 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
}
p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset;
}
};
concurrency::ThreadPool::TryParallelFor(tp, num_slices, static_cast<double>(num_slice_dims),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
lambda(slice_idx);
}
});
return err_index == 0 ? Status::OK()
: ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index);
}
template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&) const;
template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const;
template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const;
Status GatherND::Compute(OpKernelContext* context) const {
Prepare p;
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(context->Input<Tensor>(1)->IsDataType<int32_t>()
? PrepareForCompute<int32_t>(context, p)
: PrepareForCompute<int64_t>(context, p));
? PrepareForCompute<int32_t>(context, p, tp)
: PrepareForCompute<int64_t>(context, p, tp));
return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p);
return nullptr == p.input_str_base ? GatherNumber(p, tp) : GatherString(p, tp);
}
Status GatherND::GatherNumber(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
Status GatherND::GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const {
auto lambda = [&](int64_t slice_idx) {
memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes,
p.bytes_per_slice);
}
};
concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast<double>(p.bytes_per_slice),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
lambda(slice_idx);
}
});
return Status::OK();
}
Status GatherND::GatherString(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
Status GatherND::GatherString(const Prepare& p, concurrency::ThreadPool* tp) const {
auto lambda = [&](int64_t slice_idx) {
const int64_t slice_base_offset = slice_idx * p.element_count_per_slice;
for (int64_t j = 0; j < static_cast<int64_t>(p.element_count_per_slice); ++j) {
p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j];
}
}
};
concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast<double>(p.element_count_per_slice),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
lambda(slice_idx);
}
});
return Status::OK();
}

View file

@ -5,10 +5,11 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
class GatherNDBase {
protected:
struct Prepare {
@ -32,7 +33,7 @@ class GatherNDBase {
}; // struct Prepare
template <typename Tind>
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
Status PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const;
int64_t batch_dims_;
}; // class GatherNDBase
@ -44,8 +45,8 @@ class GatherND final : public OpKernel, protected GatherNDBase {
Status Compute(OpKernelContext* context) const override;
private:
Status GatherNumber(const Prepare& p) const;
Status GatherString(const Prepare& p) const;
Status GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const;
Status GatherString(const Prepare& p, concurrency::ThreadPool* tp) const;
};
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "scatter_nd.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
@ -9,37 +10,36 @@ ONNX_CPU_OPERATOR_KERNEL(
ScatterND,
11,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
ScatterND);
template<typename Tind>
template <typename Tind>
Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
auto input_tensor = context->Input<Tensor>(0);
auto input_tensor = context->Input<Tensor>(0);
auto indice_tensor = context->Input<Tensor>(1);
auto update_tensor = context->Input<Tensor>(2);
ORT_ENFORCE(input_tensor != nullptr);
ORT_ENFORCE(input_tensor != nullptr);
ORT_ENFORCE(indice_tensor != nullptr);
ORT_ENFORCE(update_tensor != nullptr);
auto input_shape = input_tensor->Shape();
auto indice_shape = indice_tensor->Shape();
auto update_shape = update_tensor->Shape();
auto input_shape = input_tensor->Shape();
auto indice_shape = indice_tensor->Shape();
auto update_shape = update_tensor->Shape();
if (indice_shape.NumDimensions() == 0 || input_shape.NumDimensions() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input tensor and indices tensor must has rank larger than 0. ",
"input shape: ", input_shape, ", indices shape: ", indice_shape);
"input tensor and indices tensor must has rank larger than 0. ",
"input shape: ", input_shape, ", indices shape: ", indice_shape);
}
auto indice_rank = indice_shape.NumDimensions();
auto last_indice_dimension = indice_shape[indice_rank - 1];
if (last_indice_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"last dimension of indices must not be larger than rank of input tensor");
"last dimension of indices must not be larger than rank of input tensor");
}
bool is_update_shape_invalid = [&](){
bool is_update_shape_invalid = [&]() {
auto update_rank = update_shape.NumDimensions();
auto input_rank = input_shape.NumDimensions();
if (update_rank < indice_rank - 1) {
@ -59,8 +59,8 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
}();
if (is_update_shape_invalid) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ",
"updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape);
"updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ",
"updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape);
}
auto output_tensor = context->Output(0, TensorShape(input_shape));
@ -81,34 +81,28 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
}
}
std::vector<int64_t> element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension
std::vector<int64_t> element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < last_indice_dimension; ++i) {
element_counts[i] = input_shape.SizeFromDimension(i + 1);
}
int64_t err_indice = 0;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
auto indice_offset = static_cast<const Tind*>(indice_tensor->DataRaw());
auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy
auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy
p.element_offsets.assign(offset_count, 0LL);
if (input_tensor->IsDataTypeString()) {
p.input_str_base = static_cast<const std::string*>(update_tensor->DataRaw());
p.input_str_base = static_cast<const std::string*>(update_tensor->DataRaw());
p.output_str_base = static_cast<std::string*>(output_tensor->MutableDataRaw());
} else {
p.input_base = static_cast<const uint8_t*>(update_tensor->DataRaw());
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
p.input_base = static_cast<const uint8_t*>(update_tensor->DataRaw());
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
}
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < offset_count; ++i) {
for (int64_t j = 0; j < last_indice_dimension; ++j) {
auto indice = *(indice_offset + i * last_indice_dimension + j);
@ -118,40 +112,46 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
p.element_offsets[i] += indice * element_counts[j];
}
}
return err_indice == 0 ? Status::OK() :
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
}
template Status ScatterNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
Status ScatterND::Compute(OpKernelContext* context) const {
Prepare p;
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(PrepareForCompute<int64_t>(context, p));
return nullptr == p.input_str_base ? ScatterNumber(p) : ScatterString(p);
return nullptr == p.input_str_base ? ScatterNumber(p, tp) : ScatterString(p, tp);
}
Status ScatterND::ScatterNumber(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
Status ScatterND::ScatterNumber(const Prepare& p, concurrency::ThreadPool* tp) const {
auto lambda = [&](int64_t i) {
memcpy(p.output_base + p.element_offsets[i] * p.element_bytes,
p.input_base + i * p.bytes_to_copy,
p.bytes_to_copy);
}
};
concurrency::ThreadPool::TryParallelFor(tp, p.element_offsets.size(), static_cast<double>(p.bytes_to_copy),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int i = static_cast<int>(first), end = static_cast<int>(last); i < end; ++i) {
lambda(i);
}
});
return Status::OK();
}
Status ScatterND::ScatterString(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
Status ScatterND::ScatterString(const Prepare& p, concurrency::ThreadPool* tp) const {
auto lambda = [&](int64_t i) {
for (int64_t j = 0; j < static_cast<int64_t>(p.element_to_copy); ++j) {
p.output_str_base[p.element_offsets[i] + j] = p.input_str_base[i * p.element_to_copy + j];
}
}
};
concurrency::ThreadPool::TryParallelFor(tp, p.element_offsets.size(), static_cast<double>(p.element_to_copy),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int i = static_cast<int>(first), end = static_cast<int>(last); i < end; ++i) {
lambda(i);
}
});
return Status::OK();
}
}
} // namespace onnxruntime

View file

@ -7,41 +7,43 @@
#include "core/framework/op_kernel.h"
namespace onnxruntime {
class ScatterNDBase
{
protected:
namespace concurrency {
class ThreadPool;
}
class ScatterNDBase {
protected:
struct Prepare {
const uint8_t* input_base;
const std::string* input_str_base;
uint8_t* output_base;
std::string* output_str_base;
uint64_t bytes_to_copy;
uint64_t element_bytes;
uint64_t element_to_copy;
const uint8_t* input_base;
const std::string* input_str_base;
uint8_t* output_base;
std::string* output_str_base;
uint64_t bytes_to_copy;
uint64_t element_bytes;
uint64_t element_to_copy;
std::vector<uint64_t> element_offsets;
Prepare(): input_base (nullptr),
input_str_base (nullptr),
output_base (nullptr),
output_str_base (nullptr),
bytes_to_copy (0),
element_bytes (0),
element_to_copy (0),
element_offsets (0) {}
}; // struct Prepare
Prepare() : input_base(nullptr),
input_str_base(nullptr),
output_base(nullptr),
output_str_base(nullptr),
bytes_to_copy(0),
element_bytes(0),
element_to_copy(0),
element_offsets(0) {}
}; // struct Prepare
template<typename Tind>
template <typename Tind>
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
}; // class ScatterNDBase
}; // class ScatterNDBase
class ScatterND final : public OpKernel, protected ScatterNDBase {
public:
public:
explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override;
private:
Status ScatterNumber(const Prepare& p) const;
Status ScatterString(const Prepare& p) const;
private:
Status ScatterNumber(const Prepare& p, concurrency::ThreadPool* tp) const;
Status ScatterString(const Prepare& p, concurrency::ThreadPool* tp) const;
};
} // namespace onnxruntime
} // namespace onnxruntime