From 5e8952ef89b52aec7a0646840cf6fe693cedfd48 Mon Sep 17 00:00:00 2001 From: Tim Harris Date: Wed, 28 Oct 2020 09:49:18 +0000 Subject: [PATCH] ThreadPool clean up : mm_pause in loops, correctly spin-then-wait, and adopt static methods consistently in the API (#5590) Description: This change makes three changes to the ThreadPool class to clean up issues identified during performance analysis and optimization. (1) It uses mm_pause intrinsics in spin loops, helping avoid consuming pipeline resources while waiting. (2) It re-organizes the spin-then-steal loop for work distribution to start out spinning as intended, rather than to start out trying to steal. (3) It updates the ThreadPool class's API to be consistent in the use of static methods for public functions. The PR includes minor doc updates and corresponding changes to test cases. Motivation and Context The change helps ensure consistency in behavior between the OpenMP and Eigen-based implementations. Unlike the instance methods, the static methods abstract over the different ways in which threading can be implemented; they will map onto the OpenMP or Eigen-based implementations when threading is used. When threading is not used they will run work sequentially. --- docs/NotesOnThreading.md | 5 +- include/onnxruntime/core/common/spin_pause.h | 28 ++++++++ include/onnxruntime/core/platform/Barrier.h | 3 +- .../platform/EigenNonBlockingThreadPool.h | 4 +- .../onnxruntime/core/platform/threadpool.h | 69 ++++++++++++------- onnxruntime/core/common/threadpool.cc | 22 +++--- .../core/framework/parallel_executor.cc | 2 +- .../providers/cpu/math/element_wise_ops.h | 3 +- .../test/framework/inference_session_test.cc | 4 +- onnxruntime/test/onnx/dataitem_request.cc | 2 +- onnxruntime/test/onnx/testcase_request.cc | 2 +- onnxruntime/test/platform/threadpool_test.cc | 24 +++---- .../orttraining/models/runner/data_loader.cc | 3 +- 13 files changed, 113 insertions(+), 58 deletions(-) create mode 100644 include/onnxruntime/core/common/spin_pause.h diff --git a/docs/NotesOnThreading.md b/docs/NotesOnThreading.md index f97cc77b26..c0e8bf4aee 100644 --- a/docs/NotesOnThreading.md +++ b/docs/NotesOnThreading.md @@ -9,11 +9,14 @@ When developing an op, please use these abstractions to parallelize your code. T When OpenMP is enabled, they resort to using OpenMP. When OpenMP is disabled they resort to sequential execution if the threadpool ptr is NULL or schedule the tasks on the threadpool otherwise. Examples of these abstractions are: ([threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) has more documentation for these) -* TryBatchParallelFor * TryParallelFor * TrySimpleParallelFor +* TryBatchParallelFor +* ShouldParallelize * DegreeOfParallelism +These static methods abstract over the different implementation choices. They can run over the ORT thread pool, or run over OpenMP, or run sequentially. + **Please do not write #ifdef pragma omp in operator code**. For intra op parallelism ORT users can use either OpenMP or ORT threadpool. The choice of using OpenMP is indicated by building ORT with ```--use_openmp``` switch. For inter op parallelism, however, we always use the ORT threadpool. diff --git a/include/onnxruntime/core/common/spin_pause.h b/include/onnxruntime/core/common/spin_pause.h new file mode 100644 index 0000000000..49b71e5567 --- /dev/null +++ b/include/onnxruntime/core/common/spin_pause.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +namespace onnxruntime { + +namespace concurrency { + +// Intrinsic to use in spin-loops + +inline void SpinPause() { +#if defined(_M_AMD64) || defined(__x86_64__) + _mm_pause(); +#endif +} + +} // namespace concurrency + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/platform/Barrier.h b/include/onnxruntime/core/platform/Barrier.h index f851cae40e..915cfc5095 100644 --- a/include/onnxruntime/core/platform/Barrier.h +++ b/include/onnxruntime/core/platform/Barrier.h @@ -9,6 +9,7 @@ #include +#include "core/common/spin_pause.h" #include "core/platform/ort_mutex.h" #include @@ -48,7 +49,7 @@ class Barrier { void Wait() { if (spin_) { while ((state_ >> 1) != 0) { - /* spin */ + onnxruntime::concurrency::SpinPause(); } } else { unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index ce193ebc80..13acaedbcf 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -34,6 +34,7 @@ #endif #include "core/common/denormal.h" #include "core/common/make_unique.h" +#include "core/common/spin_pause.h" #include "core/platform/ort_mutex.h" #include "core/platform/Barrier.h" @@ -861,7 +862,8 @@ int CurrentThreadId() const EIGEN_FINAL { SetGoodWorkerHint(thread_id, true); for (int i = 0; i < spin_count && !t.f && !cancelled_ && !done_; i++) { - t = (i%steal_count == 0) ? TrySteal() : q.PopFront(); + t = ((i+1)%steal_count == 0) ? TrySteal() : q.PopFront(); + onnxruntime::concurrency::SpinPause(); } SetGoodWorkerHint(thread_id, false); diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index adc4215db1..2af086d656 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -81,12 +81,14 @@ class ThreadPool { // synchronously if it cannot be enqueued. This will occur if the thread pool's // degree-of-parallelism is 1, but it may also occur for implementation-dependent // reasons such as if queues used for buffering work are full. - void Schedule(std::function fn); - - // Returns the number of shards used by ParallelForFixedBlockSizeScheduling - // with these parameters. - int NumShardsUsedByFixedBlockSizeScheduling(std::ptrdiff_t total, - std::ptrdiff_t block_size) const; + static void Schedule(ThreadPool *tp, + std::function fn) { + if (tp) { + tp->Schedule(fn); + } else { + fn(); + } + } // ParallelFor shards the "total" units of work assuming each unit of work // having roughly "cost_per_unit" cost, in cycles. Each unit of work is @@ -99,32 +101,17 @@ class ThreadPool { // Context creation. Underestimating may not fully make use of the specified // parallelism, and may also cause inefficiencies due to load balancing // issues and stragglers. - void ParallelFor(std::ptrdiff_t total, double cost_per_unit, - const std::function& fn); - static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit, + + static void TryParallelFor(ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit, const std::function& fn) { TryParallelFor(tp, total, TensorOpCost{0, 0, static_cast(cost_per_unit)}, fn); } - void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit, - const std::function& fn); - - static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit, + static void TryParallelFor(ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit, const std::function& fn); - // Return the degree of parallelism that code should assume when using the thread pool. - // This API takes into account if OpenMP is enabled/disabled, and if the thread pool ptr is - // nullptr. It decouples the degree of parallelism for use with the thread pool from - // the implementation choice of whether this matches the number of threads created in - // the pool. - // - // Currently, a loop with degree-of-parallelism N is supported by a pool of N-1 threads - // working in combination with the thread initiating the loop. - static int DegreeOfParallelism(const concurrency::ThreadPool* tp); - // Directly schedule the 'total' tasks to the underlying threadpool, without // cutting them by halves - void SimpleParallelFor(std::ptrdiff_t total, const std::function& fn); inline static void TrySimpleParallelFor(ThreadPool* tp, std::ptrdiff_t total, const std::function& fn) { @@ -225,6 +212,27 @@ class ThreadPool { return info; } + //...................................................................... + // + // The following static methods take into account whether OpenMP is + // enabled/disabled, and if the thread pool pointer is nullptr + // during sequential execution. + + // Provide a hint to the caller for whether or not to parallelize + // work. This lets a caller switch to a sequential version of an + // algorithm rather than using calls via the ParallelFor functions. + + static bool ShouldParallelize(const ThreadPool *tp); + + // Return the degree of parallelism that code should assume when using the thread pool. + // It decouples the degree of parallelism for use with the thread pool from + // the implementation choice of whether this matches the number of threads created in + // the pool. + // + // Currently, a loop with degree-of-parallelism N is supported by a pool of N-1 threads + // working in combination with the thread initiating the loop. + static int DegreeOfParallelism(const ThreadPool* tp); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); private: @@ -250,7 +258,6 @@ class ThreadPool { // Each shard may be executed on a different thread in parallel, depending on // the number of threads available in the pool. // When (i+1)*block_size > total, fn(i*block_size, total) is called instead. - // Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size). // Requires 0 < block_size <= total. void ParallelForFixedBlockSizeScheduling(std::ptrdiff_t total, std::ptrdiff_t block_size, const std::function& fn); @@ -262,6 +269,18 @@ class ThreadPool { bool ShouldParallelizeLoop(const std::ptrdiff_t num_iterations, const std::ptrdiff_t block_size = 1) const; + // Internal (non-static) parallel loop methods. Unlike the public static methods, + // these will not handle the cases of OpenMP builds. or builds without a threadpool. + void ParallelFor(std::ptrdiff_t total, double cost_per_unit, + const std::function& fn); + + void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit, + const std::function& fn); + + void SimpleParallelFor(std::ptrdiff_t total, const std::function& fn); + + void Schedule(std::function fn); + ThreadOptions thread_options_; // If a thread pool is created with degree_of_parallelism != 1 then an underlying diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index b0471f7b65..f0026de9b7 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -232,16 +232,6 @@ bool ThreadPool::ShouldParallelizeLoop(const std::ptrdiff_t num_iterations, return true; } -int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(const std::ptrdiff_t total, - const std::ptrdiff_t block_size) const { - if (!ShouldParallelizeLoop(total, block_size)) { - return 1; - } else { - // TODO:check overflow? - return static_cast((total + block_size - 1) / block_size); - } -} - using CostModel = Eigen::TensorCostModel; // Calculates block size based on (1) the iteration cost and (2) parallel @@ -324,6 +314,10 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit, ParallelFor(total, TensorOpCost{0, 0, static_cast(cost_per_unit)}, fn); } +bool ThreadPool::ShouldParallelize(const concurrency::ThreadPool* tp) { + return (DegreeOfParallelism(tp) != 1); +} + int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) { #ifdef _OPENMP // When using OpenMP, omp_get_num_threads() returns the number of threads in the @@ -362,8 +356,12 @@ void ThreadPool::TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t tota const std::function& fn) { #ifdef _OPENMP ORT_ENFORCE(total >= 0); - if (total == 1 || total == 0) { - fn(0, total); + if (total == 0) { + return; + } + + if (total == 1) { + fn(0, 1); return; } diff --git a/onnxruntime/core/framework/parallel_executor.cc b/onnxruntime/core/framework/parallel_executor.cc index 72edb410bd..d7865b2767 100644 --- a/onnxruntime/core/framework/parallel_executor.cc +++ b/onnxruntime/core/framework/parallel_executor.cc @@ -288,7 +288,7 @@ void ParallelExecutor::EnqueueNode(size_t p_node_index, const SessionState& sess out_standings_++; } - executor_pool_->Schedule([this, p_node_index, &session_state, &logger]() { + onnxruntime::concurrency::ThreadPool::Schedule(executor_pool_, [this, p_node_index, &session_state, &logger]() { auto create_exception_message = [p_node_index, &session_state](const std::exception* ex) { const auto* node = session_state.GetGraphViewer().GetNode(p_node_index); diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 20eb1ed726..e17aedbc35 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -925,7 +925,8 @@ template void BroadcastLooper(TBroadcastHelper& helper, const ProcessBroadcastSpanFuncs& functors) { ORT_ENFORCE(helper.HaveTwoTensorInputs(), "BroadcastLooper requires two tensors as input."); - if (helper.Threadpool() != nullptr && helper.SingleSpanOutput()) { + bool par_available = concurrency::ThreadPool::ShouldParallelize(helper.Threadpool()); + if (par_available && helper.SingleSpanOutput()) { ParallelizeSingleSpan(helper, functors); } else { if (helper.IsInput0Scalar()) { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 304b08f9c4..ed725461f9 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -50,6 +50,8 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; +using namespace onnxruntime::concurrency; + namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); @@ -2547,7 +2549,7 @@ void VerifyThreadPoolWithDenormalAsZero(onnxruntime::concurrency::ThreadPool* tp std::array input_double; input_double.fill(denormal_double); - tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) { + ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) { input_float[i] *= 2; input_double[i] *= 2; }); diff --git a/onnxruntime/test/onnx/dataitem_request.cc b/onnxruntime/test/onnx/dataitem_request.cc index 3ab04a931d..168067558e 100644 --- a/onnxruntime/test/onnx/dataitem_request.cc +++ b/onnxruntime/test/onnx/dataitem_request.cc @@ -42,7 +42,7 @@ void DataTaskRequestContext::Request(const Callback& cb, concurrency::ThreadPool std::unique_ptr self(new DataTaskRequestContext(cb, c, session, allocator, task_id)); CallableFactory f(self.get()); auto runnable = f.GetCallable<&DataTaskRequestContext::RunAsync>(); - tp->Schedule([runnable]() { runnable.Invoke(); }); + onnxruntime::concurrency::ThreadPool::Schedule(tp, [runnable]() { runnable.Invoke(); }); self.release(); } diff --git a/onnxruntime/test/onnx/testcase_request.cc b/onnxruntime/test/onnx/testcase_request.cc index 10de15aa12..35469c0e81 100644 --- a/onnxruntime/test/onnx/testcase_request.cc +++ b/onnxruntime/test/onnx/testcase_request.cc @@ -88,7 +88,7 @@ void TestCaseRequestContext::Request(const Callback& cb, PThreadPool tpool, std::unique_ptr self(new TestCaseRequestContext(cb, tpool, c, env, session_opts, test_case_id)); CallableFactory f(self.get()); auto runnable = f.GetCallable<&TestCaseRequestContext::RunAsync>(); - tpool->Schedule([runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); }); + onnxruntime::concurrency::ThreadPool::Schedule(tpool, [runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); }); self.release(); } diff --git a/onnxruntime/test/platform/threadpool_test.cc b/onnxruntime/test/platform/threadpool_test.cc index 34ca40b87e..f2e4930081 100644 --- a/onnxruntime/test/platform/threadpool_test.cc +++ b/onnxruntime/test/platform/threadpool_test.cc @@ -52,7 +52,7 @@ void CreateThreadPoolAndTest(const std::string&, int num_threads, const std::fun void TestParallelFor(const std::string& name, int num_threads, int num_tasks) { auto test_data = CreateTestData(num_tasks); CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) { - tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) { IncrementElement(*test_data, i); }); + ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) { IncrementElement(*test_data, i); }); }); ValidateTestData(*test_data); } @@ -82,15 +82,15 @@ void TestMultipleParallelFor(const std::string& name, int num_threads, int num_c // For a range of scenarios, run some tests via the thread pool, and one directly for (int c = 0; c < num_concurrent - 1; c++) { - tp->Schedule([&, c]() { - tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) { - IncrementElement(*td[c], i); + ThreadPool::Schedule(tp, [&, c]() { + ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) { + IncrementElement(*td[c], i); + }); + b.Notify(); }); - b.Notify(); - }); } - tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) { + ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) { IncrementElement(*td[num_concurrent - 1], i); }); @@ -117,7 +117,7 @@ void TestBurstScheduling(const std::string& name, int num_tasks) { CreateThreadPoolAndTest(name, 2, [&](ThreadPool* tp) { // First variant : schedule from outside the pool for (int tasks = 0; tasks < num_tasks; tasks++) { - tp->Schedule([&]() { + ThreadPool::Schedule(tp, [&]() { ctr++; }); } @@ -125,9 +125,9 @@ void TestBurstScheduling(const std::string& name, int num_tasks) { ASSERT_TRUE(ctr == num_tasks); CreateThreadPoolAndTest(name, 2, [&](ThreadPool* tp) { // Second variant : schedule from inside the pool - tp->Schedule([&, tp]() { + ThreadPool::Schedule(tp, [&, tp]() { for (int tasks = 0; tasks < num_tasks; tasks++) { - tp->Schedule([&]() { + ThreadPool::Schedule(tp, [&]() { ctr++; }); } @@ -153,7 +153,7 @@ void TestPoolCreation(const std::string&, int iter) { nullptr, num_threads, true); - tp->ParallelFor(per_iter, 0.0, + ThreadPool::TryParallelFor(tp.get(), per_iter, 0.0, [&](std::ptrdiff_t s, std::ptrdiff_t e) { ctr += e - s; }); @@ -302,7 +302,7 @@ TEST(ThreadPoolTest, TestStackSize) { Notification n; ULONG_PTR low_limit, high_limit; bool has_thread_limit_info = false; - tp->Schedule([&]() { + ThreadPool::Schedule(tp.get(), [&]() { HMODULE kernel32_module = GetModuleHandle(TEXT("kernel32.dll")); assert(kernel32_module != nullptr); FnGetCurrentThreadStackLimits GetTS = diff --git a/orttraining/orttraining/models/runner/data_loader.cc b/orttraining/orttraining/models/runner/data_loader.cc index 1376db0629..756541b6af 100644 --- a/orttraining/orttraining/models/runner/data_loader.cc +++ b/orttraining/orttraining/models/runner/data_loader.cc @@ -12,6 +12,7 @@ namespace training { using FileInputStream = google::protobuf::io::FileInputStream; using CodedInputStream = google::protobuf::io::CodedInputStream; +using ThreadPool = onnxruntime::concurrency::ThreadPool; static std::vector GetAllDataFiles(const PathString& dir_path) { std::vector data_files; @@ -117,7 +118,7 @@ void DataLoader::EnsurePreloadedOrThrow() { } void DataLoader::LoadAndRemoveInternalAsync(size_t index_to_load, bool need_remove, size_t index_to_remove) { - data_loader_thread_pool_->Schedule([this, index_to_load, need_remove, index_to_remove]() { + ThreadPool::Schedule(data_loader_thread_pool_.get(), [this, index_to_load, need_remove, index_to_remove]() { std::shared_ptr data_set = std::make_shared(input_tensor_names_); if (index_to_load >= NumShards()) { LOGS_DEFAULT(WARNING)