Add BatchParallelFor, TryParallelFor, TryBatchParallelFor into ThreadPool (#2476)

This commit is contained in:
Yulong Wang 2019-11-27 00:32:26 -08:00 committed by GitHub
parent d6c84925d5
commit e29fb5cef1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 177 additions and 34 deletions

View file

@ -46,6 +46,11 @@ class ThreadPool {
*/
void ParallelFor(int32_t total, std::function<void(int32_t)> fn);
/*
Schedule work in the interval [0, total), with calls split into (num_batches) batches.
*/
void BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches = 0);
/*
Schedule work in the interval [first, last].
*/
@ -54,6 +59,43 @@ class ThreadPool {
// This is not supported until the latest Eigen
// void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions);
/**
Tries to call the given function in parallel, with calls split into (num_batches) batches.
**/
template <typename F>
inline static void TryBatchParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn, int32_t num_batches = 0) {
if (tp != nullptr) {
if (num_batches <= 0) {
num_batches = tp->NumThreads() + 1;
}
tp->BatchParallelFor(total, std::forward<F>(fn), num_batches);
} else {
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int32_t i = 0; i < total; ++i) {
fn(i);
}
}
}
/**
Tries to call the given function in parallel.
**/
template <typename F>
inline static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn) {
if (tp != nullptr) {
tp->ParallelFor(total, std::forward<F>(fn));
} else {
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int32_t i = 0; i < total; ++i) {
fn(i);
}
}
}
int NumThreads() const;
int CurrentThreadId() const;

View file

@ -43,17 +43,6 @@ namespace contrib {
ADD_TYPED_CROPANDRESIZE_OP(float);
template <typename T>
static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
if (tp != nullptr)
tp->ParallelFor(total, fn);
else {
for (int32_t i = 0; i != total; ++i) {
fn(i);
}
}
}
template <typename T>
void CropAndResizeForward(const TensorShape& output_shape,
const T* bottom_data,
@ -71,9 +60,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
int64_t pooled_height = output_shape[2];
int64_t pooled_width = output_shape[3];
// TODO: This should do blocks of work based on the number of threads in the threadpool with each block
// being n_rois / num_threads
std::function<void(int32_t)> work_object = [&](int32_t n) {
ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
int64_t index_n = n * channels * pooled_width * pooled_height;
const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
@ -182,9 +169,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
}
} // for pw
} // for ph
}; // for n
TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
}); // for n
}
template <typename T>

View file

@ -57,6 +57,36 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
barrier.Wait();
}
void ThreadPool::BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches) {
if (total <= 0)
return;
if (total == 1) {
fn(0);
return;
}
if (num_batches <= 1) {
for (int i = 0; i < total; i++) {
fn(i);
}
return;
}
if (num_batches >= total) {
ParallelFor(total, fn);
return;
}
ParallelFor(num_batches, [&](int batch_index) {
int start = batch_index * total / num_batches;
int end = (batch_index + 1) * total / num_batches;
for (int i = start; i < end; i++) {
fn(i);
}
});
}
void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
if (last <= first) return;
if (last - first == 1) {

View file

@ -42,17 +42,6 @@ ADD_TYPED_ROIALIGN_OP(float);
ADD_TYPED_ROIALIGN_OP(double);
namespace {
template <typename T>
void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
if (tp != nullptr)
tp->ParallelFor(total, fn);
else {
for (int32_t i = 0; i != total; ++i) {
fn(i);
}
}
}
template <typename T>
struct PreCalc {
int64_t pos1;
@ -183,9 +172,7 @@ void RoiAlignForward(const TensorShape& output_shape,
int64_t pooled_height = output_shape[2];
int64_t pooled_width = output_shape[3];
// TODO: This should do blocks of work based on the number of threads in the threadpool with each block
// being n_rois / num_threads
std::function<void(int32_t)> work_object = [&](int32_t n) {
ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
int64_t index_n = n * channels * pooled_width * pooled_height;
const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
@ -281,9 +268,7 @@ void RoiAlignForward(const TensorShape& output_shape,
} // for pw
} // for ph
} // for c
}; // for n
TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
}); // for n
}
} // namespace

View file

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/platform/threadpool.h"
#include <core/common/make_unique.h>
#include "gtest/gtest.h"
#include <algorithm>
#include <memory>
#include <functional>
#include <mutex>
using namespace onnxruntime::concurrency;
namespace {
struct TestData {
explicit TestData(int num) : data(num, 0) {}
std::vector<int> data;
std::mutex mutex;
};
// This unittest tests ThreadPool function by counting the number of calls to function with each index.
// the function should be called exactly once for each element.
std::unique_ptr<TestData> CreateTestData(int num) {
return onnxruntime::make_unique<TestData>(num);
}
void IncrementElement(TestData& test_data, int i) {
std::lock_guard<std::mutex> lock(test_data.mutex);
test_data.data[i]++;
}
void ValidateTestData(TestData& test_data) {
ASSERT_TRUE(std::count_if(test_data.data.cbegin(),
test_data.data.cend(),
[](int i) { return i != 1; }) == 0);
}
void CreateThreadPoolAndTest(const std::string& name, int num_threads, const std::function<void(ThreadPool*)>& test_body) {
auto tp = onnxruntime::make_unique<ThreadPool>(name, num_threads);
test_body(tp.get());
}
void TestParallelFor(const std::string& name, int num_threads, int num_tasks) {
auto test_data = CreateTestData(num_tasks);
CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
tp->ParallelFor(num_tasks, [&](int i) {
IncrementElement(*test_data, i);
});
});
ValidateTestData(*test_data);
}
void TestBatchParallelFor(const std::string& name, int num_threads, int num_tasks, int batch_size) {
auto test_data = CreateTestData(num_tasks);
CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
tp->BatchParallelFor(
num_tasks, [&](int i) {
IncrementElement(*test_data, i);
},
batch_size);
});
ValidateTestData(*test_data);
}
} // namespace
TEST(ThreadPoolTest, TestParallelFor_2_Thread_NoTask) {
TestParallelFor("TestParallelFor_2_Thread_NoTask", 2, 0);
}
TEST(ThreadPoolTest, TestParallelFor_2_Thread_50_Task) {
TestParallelFor("TestParallelFor_2_Thread_50_Task", 2, 50);
}
TEST(ThreadPoolTest, TestParallelFor_1_Thread_50_Task) {
TestParallelFor("TestParallelFor_1_Thread_50_Task", 1, 50);
}
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_10_Batch) {
TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_10_Batch", 2, 50, 10);
}
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_0_Batch) {
TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_0_Batch", 2, 50, 0);
}
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_1_Batch) {
TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_1_Batch", 2, 50, 1);
}
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_100_Batch) {
TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_100_Batch", 2, 50, 100);
}
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_81_Task_20_Batch) {
TestBatchParallelFor("TestBatchParallelFor_2_Thread_81_Task_20_Batch", 2, 81, 20);
}