mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add BatchParallelFor, TryParallelFor, TryBatchParallelFor into ThreadPool (#2476)
This commit is contained in:
parent
d6c84925d5
commit
e29fb5cef1
5 changed files with 177 additions and 34 deletions
|
|
@ -46,6 +46,11 @@ class ThreadPool {
|
|||
*/
|
||||
void ParallelFor(int32_t total, std::function<void(int32_t)> fn);
|
||||
|
||||
/*
|
||||
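Schedule work in the interval [0, total), with calls split into (num_batches) batches.
If num_batches <= 1, every call runs sequentially on the calling thread.
If num_batches >= total, each index is scheduled individually through ParallelFor.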
|
||||
*/
|
||||
void BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches = 0);
|
||||
|
||||
/*
|
||||
Schedule work in the interval [first, last].
|
||||
*/
|
||||
|
|
@ -54,6 +59,43 @@ class ThreadPool {
|
|||
// This is not supported until the latest Eigen
|
||||
// void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions);
|
||||
|
||||
/**
|
||||
Tries to call the given function in parallel, with calls split into (num_batches) batches.
|
||||
**/
|
||||
template <typename F>
|
||||
inline static void TryBatchParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn, int32_t num_batches = 0) {
|
||||
if (tp != nullptr) {
|
||||
if (num_batches <= 0) {
|
||||
num_batches = tp->NumThreads() + 1;
|
||||
}
|
||||
tp->BatchParallelFor(total, std::forward<F>(fn), num_batches);
|
||||
} else {
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int32_t i = 0; i < total; ++i) {
|
||||
fn(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Tries to call the given function in parallel.
|
||||
**/
|
||||
template <typename F>
|
||||
inline static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn) {
|
||||
if (tp != nullptr) {
|
||||
tp->ParallelFor(total, std::forward<F>(fn));
|
||||
} else {
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int32_t i = 0; i < total; ++i) {
|
||||
fn(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int NumThreads() const;
|
||||
|
||||
int CurrentThreadId() const;
|
||||
|
|
|
|||
|
|
@ -43,17 +43,6 @@ namespace contrib {
|
|||
|
||||
ADD_TYPED_CROPANDRESIZE_OP(float);
|
||||
|
||||
template <typename T>
|
||||
static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
|
||||
if (tp != nullptr)
|
||||
tp->ParallelFor(total, fn);
|
||||
else {
|
||||
for (int32_t i = 0; i != total; ++i) {
|
||||
fn(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CropAndResizeForward(const TensorShape& output_shape,
|
||||
const T* bottom_data,
|
||||
|
|
@ -71,9 +60,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
|
|||
int64_t pooled_height = output_shape[2];
|
||||
int64_t pooled_width = output_shape[3];
|
||||
|
||||
// TODO: This should do blocks of work based on the number of threads in the threadpool with each block
|
||||
// being n_rois / num_threads
|
||||
std::function<void(int32_t)> work_object = [&](int32_t n) {
|
||||
ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
|
||||
int64_t index_n = n * channels * pooled_width * pooled_height;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
|
||||
|
|
@ -182,9 +169,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
|
|||
}
|
||||
} // for pw
|
||||
} // for ph
|
||||
}; // for n
|
||||
|
||||
TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
|
||||
}); // for n
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -57,6 +57,36 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
|
|||
barrier.Wait();
|
||||
}
|
||||
|
||||
// Calls fn(i) for every i in [0, total), grouping the calls into num_batches
// contiguous batches that are scheduled through ParallelFor.
//
//   total <= 0            : nothing is called.
//   total == 1            : fn(0) runs directly on the calling thread.
//   num_batches <= 1      : all calls run sequentially on the calling thread.
//   num_batches >= total  : each index is scheduled individually via ParallelFor.
//   otherwise             : batch b covers [b * total / num_batches,
//                                           (b + 1) * total / num_batches).
void ThreadPool::BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches) {
  if (total <= 0)
    return;

  if (total == 1) {
    fn(0);
    return;
  }

  if (num_batches <= 1) {
    for (int32_t i = 0; i < total; i++) {
      fn(i);
    }
    return;
  }

  if (num_batches >= total) {
    ParallelFor(total, fn);
    return;
  }

  // Here 1 < num_batches < total, so batch_index * total can exceed INT32_MAX
  // for large totals. Do the boundary arithmetic in 64 bits; the quotients are
  // always within [0, total] and therefore fit back into int32_t.
  const int64_t total64 = total;
  const int64_t batches64 = num_batches;
  ParallelFor(num_batches, [&](int32_t batch_index) {
    const int32_t start = static_cast<int32_t>(batch_index * total64 / batches64);
    const int32_t end = static_cast<int32_t>((batch_index + 1) * total64 / batches64);
    for (int32_t i = start; i < end; i++) {
      fn(i);
    }
  });
}
|
||||
|
||||
void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
|
||||
if (last <= first) return;
|
||||
if (last - first == 1) {
|
||||
|
|
|
|||
|
|
@ -42,17 +42,6 @@ ADD_TYPED_ROIALIGN_OP(float);
|
|||
ADD_TYPED_ROIALIGN_OP(double);
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
|
||||
if (tp != nullptr)
|
||||
tp->ParallelFor(total, fn);
|
||||
else {
|
||||
for (int32_t i = 0; i != total; ++i) {
|
||||
fn(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct PreCalc {
|
||||
int64_t pos1;
|
||||
|
|
@ -183,9 +172,7 @@ void RoiAlignForward(const TensorShape& output_shape,
|
|||
int64_t pooled_height = output_shape[2];
|
||||
int64_t pooled_width = output_shape[3];
|
||||
|
||||
// TODO: This should do blocks of work based on the number of threads in the threadpool with each block
|
||||
// being n_rois / num_threads
|
||||
std::function<void(int32_t)> work_object = [&](int32_t n) {
|
||||
ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
|
||||
int64_t index_n = n * channels * pooled_width * pooled_height;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
|
||||
|
|
@ -281,9 +268,7 @@ void RoiAlignForward(const TensorShape& output_shape,
|
|||
} // for pw
|
||||
} // for ph
|
||||
} // for c
|
||||
}; // for n
|
||||
|
||||
TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
|
||||
}); // for n
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
|||
101
onnxruntime/test/platform/threadpool_test.cc
Normal file
101
onnxruntime/test/platform/threadpool_test.cc
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
#include <core/common/make_unique.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
|
||||
using namespace onnxruntime::concurrency;
|
||||
|
||||
namespace {
|
||||
|
||||
// State shared by one test run: a call counter per index and a mutex for it.
struct TestData {
  // Creates `num` counters, each starting at zero.
  explicit TestData(int num) : data(num) {}

  std::vector<int> data;  // data[i] is the number of calls seen for index i
  std::mutex mutex;       // held while a counter in `data` is updated
};
|
||||
|
||||
// These unit tests exercise the ThreadPool functions by counting how many times the
// callback is invoked for each index. The callback must be called exactly once per element.
|
||||
|
||||
std::unique_ptr<TestData> CreateTestData(int num) {
|
||||
return onnxruntime::make_unique<TestData>(num);
|
||||
}
|
||||
|
||||
// Records one call for index `i`. The update happens under test_data.mutex,
// so the function can be invoked concurrently from pool threads.
void IncrementElement(TestData& test_data, int i) {
  std::lock_guard<std::mutex> guard(test_data.mutex);
  ++test_data.data[i];
}
|
||||
|
||||
// Asserts that every index was visited exactly once, i.e. that no counter in
// test_data.data differs from 1. An empty data vector passes trivially.
void ValidateTestData(TestData& test_data) {
  // none_of stops at the first offending counter instead of counting them all.
  const bool every_index_hit_once = std::none_of(test_data.data.cbegin(),
                                                 test_data.data.cend(),
                                                 [](int calls) { return calls != 1; });
  ASSERT_TRUE(every_index_hit_once);
}
|
||||
|
||||
void CreateThreadPoolAndTest(const std::string& name, int num_threads, const std::function<void(ThreadPool*)>& test_body) {
|
||||
auto tp = onnxruntime::make_unique<ThreadPool>(name, num_threads);
|
||||
test_body(tp.get());
|
||||
}
|
||||
|
||||
void TestParallelFor(const std::string& name, int num_threads, int num_tasks) {
|
||||
auto test_data = CreateTestData(num_tasks);
|
||||
CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
|
||||
tp->ParallelFor(num_tasks, [&](int i) {
|
||||
IncrementElement(*test_data, i);
|
||||
});
|
||||
});
|
||||
ValidateTestData(*test_data);
|
||||
}
|
||||
|
||||
void TestBatchParallelFor(const std::string& name, int num_threads, int num_tasks, int batch_size) {
|
||||
auto test_data = CreateTestData(num_tasks);
|
||||
CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
|
||||
tp->BatchParallelFor(
|
||||
num_tasks, [&](int i) {
|
||||
IncrementElement(*test_data, i);
|
||||
},
|
||||
batch_size);
|
||||
});
|
||||
ValidateTestData(*test_data);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ParallelFor with zero tasks: nothing should be called.
TEST(ThreadPoolTest, TestParallelFor_2_Thread_NoTask) {
  TestParallelFor("TestParallelFor_2_Thread_NoTask", 2, 0);
}

// ParallelFor with 50 tasks on 2 threads.
TEST(ThreadPoolTest, TestParallelFor_2_Thread_50_Task) {
  TestParallelFor("TestParallelFor_2_Thread_50_Task", 2, 50);
}

// ParallelFor with 50 tasks on 1 thread.
TEST(ThreadPoolTest, TestParallelFor_1_Thread_50_Task) {
  TestParallelFor("TestParallelFor_1_Thread_50_Task", 1, 50);
}

// 50 tasks in 10 batches (5 tasks per batch).
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_10_Batch) {
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_10_Batch", 2, 50, 10);
}

// Batch count 0 is passed straight through to BatchParallelFor.
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_0_Batch) {
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_0_Batch", 2, 50, 0);
}

// A single batch covering all 50 tasks.
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_1_Batch) {
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_1_Batch", 2, 50, 1);
}

// More batches (100) than tasks (50).
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_100_Batch) {
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_100_Batch", 2, 50, 100);
}

// Task count (81) not evenly divisible by the batch count (20).
TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_81_Task_20_Batch) {
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_81_Task_20_Batch", 2, 81, 20);
}
|
||||
Loading…
Reference in a new issue