From e29fb5cef1d0e8ad27a4ede6f37606a368f3682d Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Wed, 27 Nov 2019 00:32:26 -0800 Subject: [PATCH] Add BatchParallelFor, TryParallelFor, TryBatchParallelFor into ThreadPool (#2476) --- .../onnxruntime/core/platform/threadpool.h | 42 ++++++++ .../contrib_ops/cpu/crop_and_resize.cc | 19 +--- onnxruntime/core/common/threadpool.cc | 30 ++++++ .../cpu/object_detection/roialign.cc | 19 +--- onnxruntime/test/platform/threadpool_test.cc | 101 ++++++++++++++++++ 5 files changed, 177 insertions(+), 34 deletions(-) create mode 100644 onnxruntime/test/platform/threadpool_test.cc diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index 3337583612..e6b442348a 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -46,6 +46,11 @@ class ThreadPool { */ void ParallelFor(int32_t total, std::function fn); + /* + Schedule work in the interval [0, total), with calls split into (num_batches) batches. + */ + void BatchParallelFor(int32_t total, std::function fn, int32_t num_batches = 0); + /* Schedule work in the interval [first, last]. */ @@ -54,6 +59,43 @@ class ThreadPool { // This is not supported until the latest Eigen // void SetStealPartitions(const std::vector>& partitions); + /** + Tries to call the given function in parallel, with calls split into (num_batches) batches. + **/ + template + inline static void TryBatchParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn, int32_t num_batches = 0) { + if (tp != nullptr) { + if (num_batches <= 0) { + num_batches = tp->NumThreads() + 1; + } + tp->BatchParallelFor(total, std::forward(fn), num_batches); + } else { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int32_t i = 0; i < total; ++i) { + fn(i); + } + } + } + + /** + Tries to call the given function in parallel. + **/ + template + inline static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn) { + if (tp != nullptr) { + tp->ParallelFor(total, std::forward(fn)); + } else { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int32_t i = 0; i < total; ++i) { + fn(i); + } + } + } + int NumThreads() const; int CurrentThreadId() const; diff --git a/onnxruntime/contrib_ops/cpu/crop_and_resize.cc b/onnxruntime/contrib_ops/cpu/crop_and_resize.cc index f966c9c821..f486d6a6a9 100644 --- a/onnxruntime/contrib_ops/cpu/crop_and_resize.cc +++ b/onnxruntime/contrib_ops/cpu/crop_and_resize.cc @@ -43,17 +43,6 @@ namespace contrib { ADD_TYPED_CROPANDRESIZE_OP(float); -template -static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) { - if (tp != nullptr) - tp->ParallelFor(total, fn); - else { - for (int32_t i = 0; i != total; ++i) { - fn(i); - } - } -} - template void CropAndResizeForward(const TensorShape& output_shape, const T* bottom_data, @@ -71,9 +60,7 @@ void CropAndResizeForward(const TensorShape& output_shape, int64_t pooled_height = output_shape[2]; int64_t pooled_width = output_shape[3]; - // TODO: This should do blocks of work based on the number of threads in the threadpool with each block - // being n_rois / num_threads - std::function work_object = [&](int32_t n) { + ThreadPool::TryBatchParallelFor(ttp, static_cast(n_rois), [&](int32_t n) { int64_t index_n = n * channels * pooled_width * pooled_height; const T* offset_bottom_rois = bottom_rois + n * num_roi_cols; @@ -182,9 +169,7 @@ void CropAndResizeForward(const TensorShape& output_shape, } } // for pw } // for ph - }; // for n - - TryParallelFor(ttp, static_cast(n_rois), work_object); + }); // for n } template diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 0595f8c56e..9f08dc68bb 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -57,6 +57,36 @@ void ThreadPool::ParallelFor(int32_t total, std::function fn) { barrier.Wait(); } +void ThreadPool::BatchParallelFor(int32_t total, std::function fn, int32_t num_batches) { + if (total <= 0) + return; + + if (total == 1) { + fn(0); + return; + } + + if (num_batches <= 1) { + for (int i = 0; i < total; i++) { + fn(i); + } + return; + } + + if (num_batches >= total) { + ParallelFor(total, fn); + return; + } + + ParallelFor(num_batches, [&](int batch_index) { + int start = batch_index * total / num_batches; + int end = (batch_index + 1) * total / num_batches; + for (int i = start; i < end; i++) { + fn(i); + } + }); +} + void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function fn) { if (last <= first) return; if (last - first == 1) { diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc index e68258cb3b..f3bab71ef5 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -42,17 +42,6 @@ ADD_TYPED_ROIALIGN_OP(float); ADD_TYPED_ROIALIGN_OP(double); namespace { -template -void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) { - if (tp != nullptr) - tp->ParallelFor(total, fn); - else { - for (int32_t i = 0; i != total; ++i) { - fn(i); - } - } -} - template struct PreCalc { int64_t pos1; @@ -183,9 +172,7 @@ void RoiAlignForward(const TensorShape& output_shape, int64_t pooled_height = output_shape[2]; int64_t pooled_width = output_shape[3]; - // TODO: This should do blocks of work based on the number of threads in the threadpool with each block - // being n_rois / num_threads - std::function work_object = [&](int32_t n) { + ThreadPool::TryBatchParallelFor(ttp, static_cast(n_rois), [&](int32_t n) { int64_t index_n = n * channels * pooled_width * pooled_height; const T* offset_bottom_rois = bottom_rois + n * num_roi_cols; @@ -281,9 +268,7 @@ void RoiAlignForward(const TensorShape& output_shape, } // for pw } // for ph } // for c - }; // for n - - TryParallelFor(ttp, static_cast(n_rois), work_object); + }); // for n } } // namespace diff --git a/onnxruntime/test/platform/threadpool_test.cc b/onnxruntime/test/platform/threadpool_test.cc new file mode 100644 index 0000000000..ec628bc293 --- /dev/null +++ b/onnxruntime/test/platform/threadpool_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/threadpool.h" + +#include + +#include "gtest/gtest.h" +#include +#include +#include +#include + +using namespace onnxruntime::concurrency; + +namespace { + +struct TestData { + explicit TestData(int num) : data(num, 0) {} + std::vector data; + std::mutex mutex; +}; + +// This unittest tests ThreadPool function by counting the number of calls to function with each index. +// the function should be called exactly once for each element. + +std::unique_ptr CreateTestData(int num) { + return onnxruntime::make_unique(num); +} + +void IncrementElement(TestData& test_data, int i) { + std::lock_guard lock(test_data.mutex); + test_data.data[i]++; +} + +void ValidateTestData(TestData& test_data) { + ASSERT_TRUE(std::count_if(test_data.data.cbegin(), + test_data.data.cend(), + [](int i) { return i != 1; }) == 0); +} + +void CreateThreadPoolAndTest(const std::string& name, int num_threads, const std::function& test_body) { + auto tp = onnxruntime::make_unique(name, num_threads); + test_body(tp.get()); +} + +void TestParallelFor(const std::string& name, int num_threads, int num_tasks) { + auto test_data = CreateTestData(num_tasks); + CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) { + tp->ParallelFor(num_tasks, [&](int i) { + IncrementElement(*test_data, i); + }); + }); + ValidateTestData(*test_data); +} + +void TestBatchParallelFor(const std::string& name, int num_threads, int num_tasks, int batch_size) { + auto test_data = CreateTestData(num_tasks); + CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) { + tp->BatchParallelFor( + num_tasks, [&](int i) { + IncrementElement(*test_data, i); + }, + batch_size); + }); + ValidateTestData(*test_data); +} + +} // namespace + +TEST(ThreadPoolTest, TestParallelFor_2_Thread_NoTask) { + TestParallelFor("TestParallelFor_2_Thread_NoTask", 2, 0); +} + +TEST(ThreadPoolTest, TestParallelFor_2_Thread_50_Task) { + TestParallelFor("TestParallelFor_2_Thread_50_Task", 2, 50); +} + +TEST(ThreadPoolTest, TestParallelFor_1_Thread_50_Task) { + TestParallelFor("TestParallelFor_1_Thread_50_Task", 1, 50); +} + +TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_10_Batch) { + TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_10_Batch", 2, 50, 10); +} + +TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_0_Batch) { + TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_0_Batch", 2, 50, 0); +} + +TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_1_Batch) { + TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_1_Batch", 2, 50, 1); +} + +TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_100_Batch) { + TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_100_Batch", 2, 50, 100); +} + +TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_81_Task_20_Batch) { + TestBatchParallelFor("TestBatchParallelFor_2_Thread_81_Task_20_Batch", 2, 81, 20); +}