ThreadPool clean up : mm_pause in loops, correctly spin-then-wait, and adopt static methods consistently in the API (#5590)

Description: This change makes three changes to the ThreadPool class to clean up issues identified during performance analysis and optimization. (1) It uses mm_pause intrinsics in spin loops, helping avoid consuming pipeline resources while waiting. (2) It re-organizes the spin-then-steal loop for work distribution to start out spinning as intended, rather than to start out trying to steal. (3) It updates the ThreadPool class's API to be consistent in the use of static methods for public functions. The PR includes minor doc updates and corresponding changes to test cases.

Motivation and Context
The change helps ensure consistency in behavior between the OpenMP and Eigen-based implementations. Unlike the instance methods, the static methods abstract over the different ways in which threading can be implemented; they will map onto the OpenMP or Eigen-based implementations when threading is used. When threading is not used they will run work sequentially.
This commit is contained in:
Tim Harris 2020-10-28 09:49:18 +00:00 committed by GitHub
parent 92662659ba
commit 5e8952ef89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 113 additions and 58 deletions

View file

@ -9,11 +9,14 @@ When developing an op, please use these abstractions to parallelize your code. T
When OpenMP is enabled, they resort to using OpenMP. When OpenMP is disabled they resort to sequential execution if the threadpool ptr is NULL or schedule the tasks on the threadpool otherwise.
Examples of these abstractions are: ([threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) has more documentation for these)
* TryBatchParallelFor
* TryParallelFor
* TrySimpleParallelFor
* TryBatchParallelFor
* ShouldParallelize
* DegreeOfParallelism
These static methods abstract over the different implementation choices. They can run over the ORT thread pool, or run over OpenMP, or run sequentially.
**Please do not write #ifdef pragma omp in operator code**.
For intra op parallelism ORT users can use either OpenMP or ORT threadpool. The choice of using OpenMP is indicated by building ORT with ```--use_openmp``` switch. For inter op parallelism, however, we always use the ORT threadpool.

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#if defined(_M_AMD64)
#include <intrin.h>
#endif
#if defined(__x86_64__)
#include <xmmintrin.h>
#endif
namespace onnxruntime {
namespace concurrency {
// Intrinsic to use in spin-loops
inline void SpinPause() {
#if defined(_M_AMD64) || defined(__x86_64__)
_mm_pause();
#endif
}
} // namespace concurrency
} // namespace onnxruntime

View file

@ -9,6 +9,7 @@
#include <assert.h>
#include "core/common/spin_pause.h"
#include "core/platform/ort_mutex.h"
#include <mutex>
@ -48,7 +49,7 @@ class Barrier {
void Wait() {
if (spin_) {
while ((state_ >> 1) != 0) {
/* spin */
onnxruntime::concurrency::SpinPause();
}
} else {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);

View file

@ -34,6 +34,7 @@
#endif
#include "core/common/denormal.h"
#include "core/common/make_unique.h"
#include "core/common/spin_pause.h"
#include "core/platform/ort_mutex.h"
#include "core/platform/Barrier.h"
@ -861,7 +862,8 @@ int CurrentThreadId() const EIGEN_FINAL {
SetGoodWorkerHint(thread_id, true);
for (int i = 0; i < spin_count && !t.f && !cancelled_ && !done_; i++) {
t = (i%steal_count == 0) ? TrySteal() : q.PopFront();
t = ((i+1)%steal_count == 0) ? TrySteal() : q.PopFront();
onnxruntime::concurrency::SpinPause();
}
SetGoodWorkerHint(thread_id, false);

View file

@ -81,12 +81,14 @@ class ThreadPool {
// synchronously if it cannot be enqueued. This will occur if the thread pool's
// degree-of-parallelism is 1, but it may also occur for implementation-dependent
// reasons such as if queues used for buffering work are full.
void Schedule(std::function<void()> fn);
// Returns the number of shards used by ParallelForFixedBlockSizeScheduling
// with these parameters.
int NumShardsUsedByFixedBlockSizeScheduling(std::ptrdiff_t total,
std::ptrdiff_t block_size) const;
static void Schedule(ThreadPool *tp,
std::function<void()> fn) {
if (tp) {
tp->Schedule(fn);
} else {
fn();
}
}
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
@ -99,32 +101,17 @@ class ThreadPool {
// Context creation. Underestimating may not fully make use of the specified
// parallelism, and may also cause inefficiencies due to load balancing
// issues and stragglers.
void ParallelFor(std::ptrdiff_t total, double cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn);
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit,
static void TryParallelFor(ThreadPool* tp, std::ptrdiff_t total, double cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
TryParallelFor(tp, total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
}
void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t)>& fn);
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
static void TryParallelFor(ThreadPool* tp, std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn);
// Return the degree of parallelism that code should assume when using the thread pool.
// This API takes into account if OpenMP is enabled/disabled, and if the thread pool ptr is
// nullptr. It decouples the degree of parallelism for use with the thread pool from
// the implementation choice of whether this matches the number of threads created in
// the pool.
//
// Currently, a loop with degree-of-parallelism N is supported by a pool of N-1 threads
// working in combination with the thread initiating the loop.
static int DegreeOfParallelism(const concurrency::ThreadPool* tp);
// Directly schedule the 'total' tasks to the underlying threadpool, without
// cutting them by halves
void SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn);
inline static void TrySimpleParallelFor(ThreadPool* tp, std::ptrdiff_t total,
const std::function<void(std::ptrdiff_t)>& fn) {
@ -225,6 +212,27 @@ class ThreadPool {
return info;
}
//......................................................................
//
// The following static methods take into account whether OpenMP is
// enabled/disabled, and if the thread pool pointer is nullptr
// during sequential execution.
// Provide a hint to the caller for whether or not to parallelize
// work. This lets a caller switch to a sequential version of an
// algorithm rather than using calls via the ParallelFor functions.
static bool ShouldParallelize(const ThreadPool *tp);
// Return the degree of parallelism that code should assume when using the thread pool.
// It decouples the degree of parallelism for use with the thread pool from
// the implementation choice of whether this matches the number of threads created in
// the pool.
//
// Currently, a loop with degree-of-parallelism N is supported by a pool of N-1 threads
// working in combination with the thread initiating the loop.
static int DegreeOfParallelism(const ThreadPool* tp);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool);
private:
@ -250,7 +258,6 @@ class ThreadPool {
// Each shard may be executed on a different thread in parallel, depending on
// the number of threads available in the pool.
// When (i+1)*block_size > total, fn(i*block_size, total) is called instead.
// Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size).
// Requires 0 < block_size <= total.
void ParallelForFixedBlockSizeScheduling(std::ptrdiff_t total, std::ptrdiff_t block_size,
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
@ -262,6 +269,18 @@ class ThreadPool {
bool ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
const std::ptrdiff_t block_size = 1) const;
// Internal (non-static) parallel loop methods. Unlike the public static methods,
// these will not handle the cases of OpenMP builds. or builds without a threadpool.
void ParallelFor(std::ptrdiff_t total, double cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn);
void ParallelFor(std::ptrdiff_t total, const TensorOpCost& cost_per_unit,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t)>& fn);
void SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn);
void Schedule(std::function<void()> fn);
ThreadOptions thread_options_;
// If a thread pool is created with degree_of_parallelism != 1 then an underlying

View file

@ -232,16 +232,6 @@ bool ThreadPool::ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
return true;
}
int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(const std::ptrdiff_t total,
const std::ptrdiff_t block_size) const {
if (!ShouldParallelizeLoop(total, block_size)) {
return 1;
} else {
// TODO:check overflow?
return static_cast<int>((total + block_size - 1) / block_size);
}
}
using CostModel = Eigen::TensorCostModel<Eigen::ThreadPoolDevice>;
// Calculates block size based on (1) the iteration cost and (2) parallel
@ -324,6 +314,10 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit,
ParallelFor(total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
}
bool ThreadPool::ShouldParallelize(const concurrency::ThreadPool* tp) {
return (DegreeOfParallelism(tp) != 1);
}
int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) {
#ifdef _OPENMP
// When using OpenMP, omp_get_num_threads() returns the number of threads in the
@ -362,8 +356,12 @@ void ThreadPool::TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t tota
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
#ifdef _OPENMP
ORT_ENFORCE(total >= 0);
if (total == 1 || total == 0) {
fn(0, total);
if (total == 0) {
return;
}
if (total == 1) {
fn(0, 1);
return;
}

View file

@ -288,7 +288,7 @@ void ParallelExecutor::EnqueueNode(size_t p_node_index, const SessionState& sess
out_standings_++;
}
executor_pool_->Schedule([this, p_node_index, &session_state, &logger]() {
onnxruntime::concurrency::ThreadPool::Schedule(executor_pool_, [this, p_node_index, &session_state, &logger]() {
auto create_exception_message = [p_node_index, &session_state](const std::exception* ex) {
const auto* node = session_state.GetGraphViewer().GetNode(p_node_index);

View file

@ -925,7 +925,8 @@ template <typename TBroadcastHelper>
void BroadcastLooper(TBroadcastHelper& helper, const ProcessBroadcastSpanFuncs& functors) {
ORT_ENFORCE(helper.HaveTwoTensorInputs(), "BroadcastLooper requires two tensors as input.");
if (helper.Threadpool() != nullptr && helper.SingleSpanOutput()) {
bool par_available = concurrency::ThreadPool::ShouldParallelize(helper.Threadpool());
if (par_available && helper.SingleSpanOutput()) {
ParallelizeSingleSpan(helper, functors);
} else {
if (helper.IsInput0Scalar()) {

View file

@ -50,6 +50,8 @@
using namespace std;
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::logging;
using namespace onnxruntime::concurrency;
namespace {
struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::KernelRegistry> kernel_registry = std::make_shared<onnxruntime::KernelRegistry>();
@ -2547,7 +2549,7 @@ void VerifyThreadPoolWithDenormalAsZero(onnxruntime::concurrency::ThreadPool* tp
std::array<double, num_tasks> input_double;
input_double.fill(denormal_double);
tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) {
ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) {
input_float[i] *= 2;
input_double[i] *= 2;
});

View file

@ -42,7 +42,7 @@ void DataTaskRequestContext::Request(const Callback& cb, concurrency::ThreadPool
std::unique_ptr<DataTaskRequestContext> self(new DataTaskRequestContext(cb, c, session, allocator, task_id));
CallableFactory<DataTaskRequestContext, void> f(self.get());
auto runnable = f.GetCallable<&DataTaskRequestContext::RunAsync>();
tp->Schedule([runnable]() { runnable.Invoke(); });
onnxruntime::concurrency::ThreadPool::Schedule(tp, [runnable]() { runnable.Invoke(); });
self.release();
}

View file

@ -88,7 +88,7 @@ void TestCaseRequestContext::Request(const Callback& cb, PThreadPool tpool,
std::unique_ptr<TestCaseRequestContext> self(new TestCaseRequestContext(cb, tpool, c, env, session_opts, test_case_id));
CallableFactory<TestCaseRequestContext, void, size_t> f(self.get());
auto runnable = f.GetCallable<&TestCaseRequestContext::RunAsync>();
tpool->Schedule([runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); });
onnxruntime::concurrency::ThreadPool::Schedule(tpool, [runnable, concurrent_runs]() { runnable.Invoke(concurrent_runs); });
self.release();
}

View file

@ -52,7 +52,7 @@ void CreateThreadPoolAndTest(const std::string&, int num_threads, const std::fun
void TestParallelFor(const std::string& name, int num_threads, int num_tasks) {
auto test_data = CreateTestData(num_tasks);
CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) { IncrementElement(*test_data, i); });
ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) { IncrementElement(*test_data, i); });
});
ValidateTestData(*test_data);
}
@ -82,15 +82,15 @@ void TestMultipleParallelFor(const std::string& name, int num_threads, int num_c
// For a range of scenarios, run some tests via the thread pool, and one directly
for (int c = 0; c < num_concurrent - 1; c++) {
tp->Schedule([&, c]() {
tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) {
IncrementElement(*td[c], i);
ThreadPool::Schedule(tp, [&, c]() {
ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) {
IncrementElement(*td[c], i);
});
b.Notify();
});
b.Notify();
});
}
tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) {
ThreadPool::TrySimpleParallelFor(tp, num_tasks, [&](std::ptrdiff_t i) {
IncrementElement(*td[num_concurrent - 1], i);
});
@ -117,7 +117,7 @@ void TestBurstScheduling(const std::string& name, int num_tasks) {
CreateThreadPoolAndTest(name, 2, [&](ThreadPool* tp) {
// First variant : schedule from outside the pool
for (int tasks = 0; tasks < num_tasks; tasks++) {
tp->Schedule([&]() {
ThreadPool::Schedule(tp, [&]() {
ctr++;
});
}
@ -125,9 +125,9 @@ void TestBurstScheduling(const std::string& name, int num_tasks) {
ASSERT_TRUE(ctr == num_tasks);
CreateThreadPoolAndTest(name, 2, [&](ThreadPool* tp) {
// Second variant : schedule from inside the pool
tp->Schedule([&, tp]() {
ThreadPool::Schedule(tp, [&, tp]() {
for (int tasks = 0; tasks < num_tasks; tasks++) {
tp->Schedule([&]() {
ThreadPool::Schedule(tp, [&]() {
ctr++;
});
}
@ -153,7 +153,7 @@ void TestPoolCreation(const std::string&, int iter) {
nullptr,
num_threads,
true);
tp->ParallelFor(per_iter, 0.0,
ThreadPool::TryParallelFor(tp.get(), per_iter, 0.0,
[&](std::ptrdiff_t s, std::ptrdiff_t e) {
ctr += e - s;
});
@ -302,7 +302,7 @@ TEST(ThreadPoolTest, TestStackSize) {
Notification n;
ULONG_PTR low_limit, high_limit;
bool has_thread_limit_info = false;
tp->Schedule([&]() {
ThreadPool::Schedule(tp.get(), [&]() {
HMODULE kernel32_module = GetModuleHandle(TEXT("kernel32.dll"));
assert(kernel32_module != nullptr);
FnGetCurrentThreadStackLimits GetTS =

View file

@ -12,6 +12,7 @@ namespace training {
using FileInputStream = google::protobuf::io::FileInputStream;
using CodedInputStream = google::protobuf::io::CodedInputStream;
using ThreadPool = onnxruntime::concurrency::ThreadPool;
static std::vector<PathString> GetAllDataFiles(const PathString& dir_path) {
std::vector<PathString> data_files;
@ -117,7 +118,7 @@ void DataLoader::EnsurePreloadedOrThrow() {
}
void DataLoader::LoadAndRemoveInternalAsync(size_t index_to_load, bool need_remove, size_t index_to_remove) {
data_loader_thread_pool_->Schedule([this, index_to_load, need_remove, index_to_remove]() {
ThreadPool::Schedule(data_loader_thread_pool_.get(), [this, index_to_load, need_remove, index_to_remove]() {
std::shared_ptr<DataSet> data_set = std::make_shared<DataSet>(input_tensor_names_);
if (index_to_load >= NumShards()) {
LOGS_DEFAULT(WARNING)