mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
While investigating an unrelated issue, I noticed that the thread pool may drop tasks when a burst of 1024+ tasks is submitted by a thread from inside the pool. Today, in general, we execute work synchronously in this case. However, there is a bug where work submitted by a thread already inside the pool will be discarded instead of executed. Currently the only scenario where I can see this occurring is when the parallel executor is used with a model in which such a large number of nodes become eligible to run all at once. This PR fixes the underlying issue and adds a test case for burst-submission of work.
1015 lines
37 KiB
C++
1015 lines
37 KiB
C++
// This file is part of Eigen, a lightweight C++ template library
|
|
// for linear algebra.
|
|
//
|
|
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla
|
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
/* Modifications Copyright (c) Microsoft. */
|
|
|
|
#include <type_traits>
|
|
|
|
#pragma once
|
|
#include "onnxruntime_config.h"
|
|
// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71:
|
|
// error: ignoring attributes on template argument "Eigen::PacketType<const float, Eigen::DefaultDevice>::type {aka
|
|
// __vector(4) float}" [-Werror=ignored-attributes]
|
|
#if defined(__GNUC__)
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
|
#elif defined(_MSC_VER)
|
|
#pragma warning(push)
|
|
#pragma warning(disable : 4127)
|
|
#pragma warning(disable : 4805)
|
|
#endif
|
|
|
|
#include "unsupported/Eigen/CXX11/ThreadPool"
|
|
|
|
#if defined(__GNUC__)
|
|
#pragma GCC diagnostic pop
|
|
#elif defined(_MSC_VER)
|
|
#pragma warning(pop)
|
|
#endif
|
|
#include "core/common/make_unique.h"
|
|
#include "core/platform/ort_mutex.h"
|
|
#include "core/platform/Barrier.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
namespace concurrency {
|
|
|
|
// Extended Eigen thread pool interface, avoiding the need to modify the ThreadPoolInterface.h
|
|
// header from the external Eigen repository.
|
|
|
|
class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface {
 public:
  // Execute fn with a degree-of-parallelism of at most n, asking the thread pool
  // for help. The calling thread counts towards n; consequently n==1 means that
  // fn simply runs in the caller. The thread pool performs the fork-join
  // synchronization itself, so once RunInParallel has returned nothing else is
  // touching state captured by fn().
  virtual void RunInParallel(std::function<void()> fn, unsigned n) = 0;
};
|
|
|
|
} // namespace concurrency
|
|
|
|
template <typename Work, typename Tag, unsigned kSize>
|
|
class RunQueue {
|
|
public:
|
|
RunQueue() : front_(0), back_(0) {
|
|
// require power-of-two for fast masking
|
|
assert((kSize & (kSize - 1)) == 0);
|
|
assert(kSize > 2); // why would you do this?
|
|
assert(kSize <= (64 << 10)); // leave enough space for counter
|
|
for (unsigned i = 0; i < kSize; i++) array_[i].state.store(ElemState::kEmpty, std::memory_order_relaxed);
|
|
}
|
|
|
|
~RunQueue() {
|
|
assert(Size() == 0);
|
|
}
|
|
|
|
// PushFront inserts w at the beginning of the queue.
|
|
// If queue is full returns w, otherwise returns default-constructed Work.
|
|
Work PushFront(Work w) {
|
|
unsigned front = front_.load(std::memory_order_relaxed);
|
|
Elem& e = array_[front & kMask];
|
|
ElemState s = e.state.load(std::memory_order_relaxed);
|
|
if (s != ElemState::kEmpty ||
|
|
!e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
|
|
return w;
|
|
front_.store(front + 1 + (kSize << 1), std::memory_order_relaxed);
|
|
e.w = std::move(w);
|
|
e.tag = Tag();
|
|
e.state.store(ElemState::kReady, std::memory_order_release);
|
|
return Work();
|
|
}
|
|
|
|
// PopFront removes and returns the first element in the queue.
|
|
// If the queue was empty returns default-constructed Work.
|
|
Work PopFront() {
|
|
unsigned front;
|
|
Elem *e;
|
|
ElemState s;
|
|
|
|
// Drain revoked items from the front of the queue. CAS to busy to synchronize with
|
|
// any attempt to take the same item from the back of the queue.
|
|
do {
|
|
front = front_.load(std::memory_order_relaxed);
|
|
e = &array_[(front - 1) & kMask];
|
|
s = e->state.load(std::memory_order_relaxed);
|
|
if (s == ElemState::kRevoked &&
|
|
e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) {
|
|
e->state.store(ElemState::kEmpty, std::memory_order_release);
|
|
front = ((front - 1) & kMask2) | (front & ~kMask2);
|
|
front_.store(front, std::memory_order_relaxed);
|
|
}
|
|
} while (s == ElemState::kRevoked);
|
|
|
|
// Attempt to take next item. State kEmpty shows the queue is empty, kBusy shows
|
|
// the work is in progress on the item at the front of the queue.
|
|
if (s != ElemState::kReady ||
|
|
!e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
|
|
return Work();
|
|
Work w = std::move(e->w);
|
|
e->tag = Tag();
|
|
e->state.store(ElemState::kEmpty, std::memory_order_release);
|
|
front = ((front - 1) & kMask2) | (front & ~kMask2);
|
|
front_.store(front, std::memory_order_relaxed);
|
|
return w;
|
|
}
|
|
|
|
// PushBack adds w at the end of the queue.
|
|
// If queue is full returns w, otherwise returns default-constructed Work.
|
|
Work PushBack(Work w) {
|
|
std::unique_lock<OrtMutex> lock(mutex_);
|
|
unsigned back = back_.load(std::memory_order_relaxed);
|
|
Elem& e = array_[(back - 1) & kMask];
|
|
ElemState s = e.state.load(std::memory_order_relaxed);
|
|
if (s != ElemState::kEmpty ||
|
|
!e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
|
|
return w;
|
|
back = ((back - 1) & kMask2) | (back & ~kMask2);
|
|
back_.store(back, std::memory_order_relaxed);
|
|
e.w = std::move(w);
|
|
e.tag = Tag();
|
|
e.state.store(ElemState::kReady, std::memory_order_release);
|
|
return Work();
|
|
}
|
|
|
|
// PushBackWithTag adds w at the end of the queue. The tag value can be used on a
|
|
// subsequent call to RevokeWithTag to remove the item from the queue in combination
|
|
// with w_idx. Typically the tag will be a per-thread ID to distinguish work
|
|
// submitted from different threads.
|
|
//
|
|
// If the queue is full, returns w, otherwise returns default-constructed work.
|
|
Work PushBackWithTag(Work w, Tag tag, unsigned &w_idx) {
|
|
std::unique_lock<OrtMutex> lock(mutex_);
|
|
unsigned back = back_.load(std::memory_order_relaxed);
|
|
w_idx = (back-1) & kMask;
|
|
Elem& e = array_[w_idx];
|
|
ElemState s = e.state.load(std::memory_order_relaxed);
|
|
if (s != ElemState::kEmpty ||
|
|
!e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
|
|
return w;
|
|
back = ((back - 1) & kMask2) | (back & ~kMask2);
|
|
back_.store(back, std::memory_order_relaxed);
|
|
e.w = std::move(w);
|
|
e.tag = tag;
|
|
e.state.store(ElemState::kReady, std::memory_order_release);
|
|
return Work();
|
|
}
|
|
|
|
// PopBack removes and returns the last elements in the queue.
|
|
Work PopBack() {
|
|
if (Empty())
|
|
return Work();
|
|
std::unique_lock<OrtMutex> lock(mutex_);
|
|
unsigned back;
|
|
Elem *e;
|
|
ElemState s;
|
|
|
|
// Drain revoked items from the back of the queue. CAS to busy to synchronize with
|
|
// any attempt to take the same item from the front of the queue.
|
|
do {
|
|
back = back_.load(std::memory_order_relaxed);
|
|
e = &array_[back & kMask];
|
|
s = e->state.load(std::memory_order_relaxed);
|
|
if (s == ElemState::kRevoked &&
|
|
e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) {
|
|
e->state.store(ElemState::kEmpty, std::memory_order_release);
|
|
back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed);
|
|
}
|
|
} while (s == ElemState::kRevoked);
|
|
|
|
if (s != ElemState::kReady ||
|
|
!e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
|
|
return Work();
|
|
Work w = std::move(e->w);
|
|
e->tag = Tag();
|
|
e->state.store(ElemState::kEmpty, std::memory_order_release);
|
|
back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed);
|
|
return w;
|
|
}
|
|
|
|
// RevokeItem removes a work item from the queue. Items are identified positionally,
|
|
// and so a tag is used to detect whether the same position is occupied by a
|
|
// different work item at the time of removal. RevokeWithTags lets threads offer work
|
|
// for parallel execution, and then revoke the offer prior to the work executing (for
|
|
// instance if the thread itself completes all of the work). Revoking the work
|
|
// lets the thread deallocate state that might otherwise have been captured by the work item
|
|
// and accessed by it.
|
|
//
|
|
// Return true iff the item is successfully revoked. If the item is not revoked then
|
|
// the caller must assume that it may still execute, for instance because it
|
|
// has been pop'd from the queue concurrent with the revocation request.
|
|
|
|
bool RevokeWithTag(Tag tag, unsigned w_idx) {
|
|
bool revoked = false;
|
|
std::unique_lock<OrtMutex> lock(mutex_);
|
|
Elem& e = array_[w_idx];
|
|
ElemState s = e.state.load(std::memory_order_relaxed);
|
|
if (s == ElemState::kReady &&
|
|
e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) {
|
|
if (e.tag == tag) {
|
|
unsigned back = back_.load(std::memory_order_relaxed);
|
|
unsigned back_idx = back & kMask;
|
|
if (back_idx != w_idx) {
|
|
// Item is not at the back of the queue, mark it in-place as revoked
|
|
e.tag = Tag();
|
|
e.w = Work();
|
|
e.state.store(ElemState::kRevoked, std::memory_order_release);
|
|
revoked = true;
|
|
} else {
|
|
// Item being removed as still at the back; shift the back pointer over it,
|
|
// and bump the version number.
|
|
e.tag = Tag();
|
|
e.w = Work();
|
|
e.state.store(ElemState::kEmpty, std::memory_order_release);
|
|
back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed);
|
|
revoked = true;
|
|
}
|
|
} else {
|
|
// Tag mismatch, i.e. work queue slot re-used
|
|
e.state.store(ElemState::kReady, std::memory_order_release);
|
|
}
|
|
}
|
|
return revoked;
|
|
}
|
|
|
|
// Size returns current queue size.
|
|
// Can be called by any thread at any time.
|
|
unsigned Size() const {
|
|
return SizeOrNotEmpty<true>();
|
|
}
|
|
|
|
// Empty tests whether container is empty.
|
|
// Can be called by any thread at any time.
|
|
bool Empty() const {
|
|
return SizeOrNotEmpty<false>() == 0;
|
|
}
|
|
|
|
// Delete all the elements from the queue.
|
|
void Flush() {
|
|
while (!Empty()) {
|
|
PopFront();
|
|
}
|
|
}
|
|
|
|
private:
|
|
static const unsigned kMask = kSize - 1;
|
|
static const unsigned kMask2 = (kSize << 1) - 1;
|
|
|
|
enum class ElemState : uint8_t {
|
|
kEmpty,
|
|
kBusy,
|
|
kReady,
|
|
kRevoked,
|
|
};
|
|
|
|
// Updates to an element are bracketed by a std::memory_order_acquire
|
|
// load from the state, and a std::memory_order_release store. Accesses
|
|
// to the front/back indices for the work queue use relaxed semantics,
|
|
// with the state of the elements being authoritative.
|
|
//
|
|
// TODO: Revisit whether there is a significant benefit for the current
|
|
// workloads in the complexity here.
|
|
struct Elem {
|
|
std::atomic<ElemState> state;
|
|
Tag tag;
|
|
Work w;
|
|
};
|
|
|
|
OrtMutex mutex_;
|
|
// Low log(kSize) + 1 bits in front_ and back_ contain rolling index of
|
|
// front/back, respectively. The remaining bits contain modification counters
|
|
// that are incremented on Push operations. This allows us to (1) distinguish
|
|
// between empty and full conditions (if we would use log(kSize) bits for
|
|
// position, these conditions would be indistinguishable); (2) obtain
|
|
// consistent snapshot of front_/back_ for Size operation using the
|
|
// modification counters.
|
|
std::atomic<unsigned> front_;
|
|
std::atomic<unsigned> back_;
|
|
Elem array_[kSize];
|
|
|
|
// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
|
|
// only whether the size is 0 is guaranteed to be correct.
|
|
// Can be called by any thread at any time.
|
|
template <bool NeedSizeEstimate>
|
|
unsigned SizeOrNotEmpty() const {
|
|
// Emptiness plays critical role in thread pool blocking. So we go to great
|
|
// effort to not produce false positives (claim non-empty queue as empty).
|
|
unsigned front = front_.load(std::memory_order_acquire);
|
|
for (;;) {
|
|
// Capture a consistent snapshot of front/tail.
|
|
unsigned back = back_.load(std::memory_order_acquire);
|
|
unsigned front1 = front_.load(std::memory_order_relaxed);
|
|
if (front != front1) {
|
|
front = front1;
|
|
std::atomic_thread_fence(std::memory_order_acquire);
|
|
continue;
|
|
}
|
|
if (NeedSizeEstimate) {
|
|
return CalculateSize(front, back);
|
|
}
|
|
// This value will be 0 if the queue is empty, and undefined otherwise.
|
|
unsigned maybe_zero = ((front ^ back) & kMask2);
|
|
// Queue size estimate must agree with maybe zero check on the queue
|
|
// empty/non-empty state.
|
|
eigen_assert((CalculateSize(front, back) == 0) == (maybe_zero == 0));
|
|
return maybe_zero;
|
|
}
|
|
}
|
|
|
|
EIGEN_ALWAYS_INLINE
|
|
unsigned CalculateSize(unsigned front, unsigned back) const {
|
|
int size = (front & kMask2) - (back & kMask2);
|
|
// Fix overflow.
|
|
if (size < 0)
|
|
size += 2 * kSize;
|
|
// Order of modification in push/pop is crafted to make the queue look
|
|
// larger than it is during concurrent modifications. E.g. push can
|
|
// increment size before the corresponding pop has decremented it.
|
|
// So the computed size can be up to kSize + 1, fix it.
|
|
if (size > static_cast<int>(kSize))
|
|
size = kSize;
|
|
return static_cast<unsigned>(size);
|
|
}
|
|
|
|
RunQueue(const RunQueue&) = delete;
|
|
void operator=(const RunQueue&) = delete;
|
|
};
|
|
|
|
// Counter from which ThreadPoolTempl::Tag values are allocated. It starts at 1; the value 0
// is what a default-constructed Tag holds.
// NOTE(review): `static` at namespace scope gives this variable internal linkage. If this
// file is included as a header from several translation units, each of them gets its own
// counter -- confirm whether that is intended.
static std::atomic<uint32_t> next_tag{1};
|
|
|
|
template <typename Environment>
|
|
class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInterface {
|
|
|
|
private:
|
|
static unsigned WorkerLoop(int id, Eigen::ThreadPoolInterface* param) {
|
|
// unsafe downcast
|
|
ThreadPoolTempl* this_ptr = (ThreadPoolTempl*)param;
|
|
this_ptr->WorkerLoop(id);
|
|
return 0;
|
|
}
|
|
|
|
public:
|
|
typedef typename Environment::Task Task;
|
|
|
|
struct Tag {
|
|
constexpr Tag() : v_(0) {
|
|
}
|
|
|
|
Tag(uint32_t v) : v_(v) {
|
|
}
|
|
|
|
// Allocate a new tag to use to identify work items from a given thread
|
|
// in RunInParallel. Ideally, threads will have unique tags, but re-use
|
|
// is not incorrect if the counter wraps (for intsance, if a long-running
|
|
// workload is calling into ORT from a fresh thread for each request).
|
|
// We must not re-use the default tag 0 which is used to identify work
|
|
// items added via Schedule as opposed to requests for help in RunInParallel.
|
|
|
|
static Tag GetNext() {
|
|
Tag t = Tag(next_tag++);
|
|
if (t.v_ == 0) {
|
|
t = Tag(next_tag++);
|
|
}
|
|
return t;
|
|
}
|
|
|
|
uint32_t Get() const {
|
|
return v_;
|
|
}
|
|
|
|
bool operator==(Tag& other) const {
|
|
return v_ == other.v_;
|
|
}
|
|
|
|
uint32_t v_ = 0;
|
|
};
|
|
|
|
// Allocates a fresh tag. Delegates to Tag::GetNext() so that the reserved default
// value 0 is never handed out, even when the underlying counter wraps around.
static Tag GetNextTag() {
  return Tag::GetNext();
}
|
|
|
|
typedef RunQueue<Task, Tag, 1024> Queue;
|
|
#ifdef _WIN32
|
|
using CHAR_TYPE = wchar_t;
|
|
#else
|
|
using CHAR_TYPE = char;
|
|
#endif
|
|
ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env,
|
|
const ThreadOptions& thread_options)
|
|
: env_(env),
|
|
num_threads_(num_threads),
|
|
allow_spinning_(allow_spinning),
|
|
worker_data_(num_threads),
|
|
all_coprimes_(num_threads),
|
|
blocked_(0),
|
|
done_(false),
|
|
cancelled_(false) {
|
|
|
|
// Calculate coprimes of all numbers [1, num_threads].
|
|
// Coprimes are used for random walks over all threads in Steal
|
|
// and NonEmptyQueueIndex. Iteration is based on the fact that if we take
|
|
// a random starting thread index t and calculate num_threads - 1 subsequent
|
|
// indices as (t + coprime) % num_threads, we will cover all threads without
|
|
// repetitions (effectively getting a presudo-random permutation of thread
|
|
// indices).
|
|
for (int i = 1; i <= num_threads_; ++i) {
|
|
all_coprimes_.emplace_back(i);
|
|
ComputeCoprimes(i, &all_coprimes_.back());
|
|
}
|
|
|
|
// Allocate space for per-thread bits to indicate which threads to consider
|
|
// preferable for pushing work. We use a regular array given that a std::vector
|
|
// cannot contain std::atomic.
|
|
num_hint_words_ = static_cast<int>((num_threads_ + bits_per_hint_word_ - 1) / bits_per_hint_word_);
|
|
good_worker_hints_ = onnxruntime::make_unique<std::atomic<uint64_t>[]>(num_hint_words_);
|
|
|
|
worker_data_.resize(num_threads_);
|
|
for (int i = 0; i < num_threads_; i++) {
|
|
worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
|
|
}
|
|
}
|
|
|
|
~ThreadPoolTempl() override {
  done_ = true;

  if (cancelled_) {
    // After a cancellation the queues may still hold entries. Drop them so that
    // the queue destructors do not assert.
    for (auto& td : worker_data_) {
      td.queue.Flush();
    }
  } else {
    // From now on threads start exiting once all of them block without work.
    // Until then they may keep working for arbitrarily long: block, submit new
    // work, unblock, and so on.
    WakeAllWorkersForExit();
  }

  // Destroy the thread objects explicitly (which joins them) rather than relying
  // on the member destruction order of this class.
  for (auto& td : worker_data_) {
    td.thread.reset();
  }
}
|
|
|
|
// Run fn(). Ordinarily, the function will be added to the thread pool and executed
|
|
// by a worker thread. If the thread pool rejects the work then fn() will instead
|
|
// execute synchronously during Schedule(fn). Currently the thread pool will only
|
|
// reject work if the queue of pending work is full.
|
|
|
|
void Schedule(std::function<void()> fn) override {
|
|
Task t = env_.CreateTask(std::move(fn));
|
|
PerThread* pt = GetPerThread();
|
|
if (pt->pool == this) {
|
|
// Worker thread of this pool, push onto the thread's queue.
|
|
Queue& q = worker_data_[pt->thread_id].queue;
|
|
t = q.PushFront(std::move(t));
|
|
} else {
|
|
// A free-standing thread (or worker of another pool), push onto a random
|
|
// queue.
|
|
int q_idx = Rand(&pt->rand) % num_threads_;
|
|
WorkerData &td = worker_data_[q_idx];
|
|
Queue& q = td.queue;
|
|
t = q.PushBack(std::move(t));
|
|
if (!t.f) {
|
|
// The queue accepted the work; ensure that the thread will pick it up
|
|
td.EnsureAwake();
|
|
}
|
|
}
|
|
|
|
// Run the work directly if the queue rejected the work
|
|
if (t.f) {
|
|
env_.ExecuteTask(t);
|
|
}
|
|
}
|
|
|
|
// The thread pool maintains a set of hints for which threads will be good to distribute
|
|
// work to. A thread is considered "good" if it is actively spinning, meaning both that
|
|
// it is not busy with existing work, and that it should respond quickly to the addition
|
|
// of new work.
|
|
|
|
// Sets (is_good == true) or clears (is_good == false) the hint bit of worker idx in the
// good_worker_hints_ bitmap. Each bitmap word holds bits_per_hint_word_ hint bits.
void SetGoodWorkerHint(int idx, bool is_good) {
  assert(idx >= 0 && idx < num_threads_);
  std::atomic<uint64_t>& word = good_worker_hints_[idx / bits_per_hint_word_];
  const uint64_t bit = 1ull << (idx % bits_per_hint_word_);
  // A single atomic read-modify-write replaces the load / compare-exchange retry loop.
  if (is_good) {
    word.fetch_or(bit);
  } else {
    word.fetch_and(~bit);
  }
}
|
|
|
|
// Retrieve hints for up to n threads to distribute work to. Threads in good_hints
|
|
// pass a best-effort check to identify spinning threads via the good_worker_hints_
|
|
// bitmap. Threads in alt_hint do not pass that test, but are distinct from those in
|
|
// good_hints, letting the caller avoid distributing more than one work item to
|
|
// any individual thread.
|
|
|
|
void GetGoodWorkerHints(int n, std::vector<unsigned>& good_hints, std::vector<unsigned>& alt_hints) {
|
|
PerThread* pt = GetPerThread();
|
|
int need_alt = n;
|
|
good_hints.clear();
|
|
alt_hints.clear();
|
|
|
|
// Iterate through the words of hints, starting from a pseudo-randomly chosen
|
|
// base. This aims to distribute work across large machines in cases we
|
|
// have multiple threads scheduling work concurrently.
|
|
|
|
unsigned base = Rand(&pt->rand) % num_hint_words_;
|
|
for (int i = 0; n && (i < num_hint_words_); i++) {
|
|
int u64_idx = (base + i) % num_hint_words_;
|
|
std::atomic<uint64_t>* u64 = &good_worker_hints_[u64_idx];
|
|
uint64_t saw = u64->load();
|
|
uint64_t want = saw;
|
|
|
|
// Pick up to n bits that are set in the current word
|
|
for (int j = 0; n && (j < bits_per_hint_word_); j++) {
|
|
uint64_t bit = 1ull << j;
|
|
int thread = u64_idx * bits_per_hint_word_ + j;
|
|
if (saw & bit) {
|
|
good_hints.push_back(thread);
|
|
want &= ~bit;
|
|
n--;
|
|
} else if (need_alt && thread < num_threads_) {
|
|
alt_hints.push_back(thread);
|
|
need_alt--;
|
|
}
|
|
}
|
|
|
|
// Best-effort attempt to remove the hints. We should measure the impact of
|
|
// contention here, but the intuition is that if we conflict on the CAS then the
|
|
// machine is likely to be busy in any case, and we will have queuing on the
|
|
// work items.
|
|
u64->compare_exchange_strong(saw, want);
|
|
}
|
|
}
|
|
|
|
void RunInParallel(std::function<void()> fn, unsigned n) override {
|
|
PerThread* my_pt = GetPerThread();
|
|
assert(n>=1);
|
|
if (n == 1 || my_pt->in_parallel) {
|
|
fn();
|
|
} else {
|
|
// We build a list of <thread,idx> pairs for each of the queues that accepts a work
|
|
// item. This lets us remove any work items that do not get executed by the threads
|
|
// that we push them to.
|
|
std::vector<std::pair<int, unsigned>> pending_items;
|
|
Barrier b(n);
|
|
|
|
my_pt->in_parallel = true;
|
|
if (!my_pt->tag.Get()) {
|
|
my_pt->tag = Tag::GetNext();
|
|
}
|
|
|
|
// Push up to n-1 copies of the work item into the queues
|
|
std::vector<unsigned> good_hints, alt_hints;
|
|
GetGoodWorkerHints(n - 1, good_hints, alt_hints);
|
|
for (unsigned i = 0; i < n - 1; i++) {
|
|
Task t = env_.CreateTask([&b, &fn]() {
|
|
fn();
|
|
b.Notify(1);
|
|
});
|
|
int q_idx;
|
|
if (i < good_hints.size()) {
|
|
q_idx = good_hints[i];
|
|
} else {
|
|
auto alt_i = i - static_cast<unsigned>(good_hints.size());
|
|
if (alt_i < alt_hints.size()) {
|
|
q_idx = alt_hints[alt_i];
|
|
} else {
|
|
q_idx = Rand(&my_pt->rand) % num_threads_;
|
|
}
|
|
}
|
|
WorkerData& td = worker_data_[q_idx];
|
|
Queue& q = td.queue;
|
|
unsigned w_idx;
|
|
t = q.PushBackWithTag(std::move(t), my_pt->tag, w_idx);
|
|
if (t.f) {
|
|
// The queue rejected the work. Account for the missing capacity for work
|
|
// on the synchronization barrier. The semantics for RunInParallel are that
|
|
// the function is called with up to n-way parallelism, and so the
|
|
// work itself will be performed in the current thread's call to fn()
|
|
// after finishing adding work to the pool.
|
|
b.Notify(1);
|
|
} else {
|
|
// The queue accepted the work, ensure that the thread is servicing the queue
|
|
pending_items.push_back({q_idx, w_idx});
|
|
td.EnsureAwake();
|
|
}
|
|
}
|
|
|
|
// Run the final copy ourselves, for the total of n degree-of-parallelism
|
|
fn();
|
|
|
|
// Notify the barrier for the work we completed, plus any work that we successfully
|
|
// revoke from the work queues
|
|
int notifications_needed = 1;
|
|
for (auto& item : pending_items) {
|
|
Queue& q = worker_data_[item.first].queue;
|
|
if (q.RevokeWithTag(my_pt->tag, item.second)) {
|
|
notifications_needed++;
|
|
}
|
|
}
|
|
b.Notify(notifications_needed);
|
|
|
|
// Synchronize with any work items that are still running
|
|
b.Wait();
|
|
my_pt->in_parallel = false;
|
|
}
|
|
}
|
|
|
|
void Cancel() override {
|
|
cancelled_ = true;
|
|
// If done_ is true, which means this object is being destructing.
|
|
// Therefore worker_data_[i].thread could be NULL.
|
|
if (!done_) {
|
|
done_ = true;
|
|
// Let each thread know it's been cancelled.
|
|
for (size_t i = 0; i < worker_data_.size(); i++) {
|
|
assert(worker_data_[i].thread != nullptr);
|
|
worker_data_[i].thread->OnCancel();
|
|
}
|
|
}
|
|
|
|
// Wake up the threads without work to let them exit on their own.
|
|
WakeAllWorkersForExit();
|
|
}
|
|
|
|
int NumThreads() const EIGEN_FINAL {
  // Number of worker threads, as passed to the constructor.
  return num_threads_;
}
|
|
|
|
int CurrentThreadId() const EIGEN_FINAL {
  // GetPerThread is non-const, hence the const_cast.
  const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
  // Worker index if the caller is a worker of this pool, -1 for any other thread.
  return (pt->pool == this) ? pt->thread_id : -1;
}
|
|
|
|
private:
|
|
|
|
#ifdef NDEBUG
|
|
void AssertBounds(int, int) {
|
|
}
|
|
#else
|
|
void AssertBounds(int start, int end) {
|
|
assert(start >= 0);
|
|
assert(start < end); // non-zero sized partition
|
|
assert(end <= num_threads_);
|
|
}
|
|
#endif
|
|
|
|
void ComputeCoprimes(int N, Eigen::MaxSizeVector<unsigned>* coprimes) {
|
|
for (int i = 1; i <= N; i++) {
|
|
unsigned a = i;
|
|
unsigned b = N;
|
|
// If GCD(a, b) == 1, then a and b are coprimes.
|
|
while (b != 0) {
|
|
unsigned tmp = a;
|
|
a = b;
|
|
b = tmp % b;
|
|
}
|
|
if (a == 1) {
|
|
coprimes->push_back(i);
|
|
}
|
|
}
|
|
}
|
|
|
|
typedef typename Environment::EnvThread Thread;
|
|
struct WorkerData;
|
|
|
|
// PerThread objects are allocated in thread-local storage and allocated
|
|
// on the thread's first call to GetPerThread. The object should
|
|
// remain trivially-destructable, with other state placed in the
|
|
// WorkerData objects that are allocated and cleaned-up explicitly.
|
|
//
|
|
// PerThread objects are allocated for all threads that submit work to
|
|
// the thread pool, in addition to threads within the pool.
|
|
//
|
|
// In contrast, the WorkerData objects are allocated only for the
|
|
// threads in the pool, and their lifetime is managed along with the
|
|
// pool.
|
|
|
|
struct PerThread {
|
|
constexpr PerThread() : pool(nullptr) {
|
|
}
|
|
ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
|
uint64_t rand{0}; // Random generator state.
|
|
int thread_id{-1}; // Worker thread index in pool.
|
|
Tag tag{}; // Work item tag used to identify this thread.
|
|
bool in_parallel{false}; // Inside a parallel section (hence tag not unique if we re-use)
|
|
};
|
|
|
|
static_assert(std::is_trivially_destructible<PerThread>::value, "Per-thread state should be trivially destructible");
|
|
|
|
struct WorkerData {
|
|
constexpr WorkerData() : thread(), queue() {
|
|
}
|
|
std::unique_ptr<Thread> thread;
|
|
Queue queue;
|
|
|
|
// Each thread has a status, available read-only without locking, and protected
|
|
// by the mutex field below for updates. The status is used for three
|
|
// purposes:
|
|
//
|
|
// 1. To identify threads that are good candidates to push work to.
|
|
// We prefer to push work to threads that are actively spinning (no need
|
|
// for an OS wake-up, and no need for current work to finish). After that, we
|
|
// prefer to push work to threads that are blocked (no need to wait for the
|
|
// current work to finish).
|
|
//
|
|
// 2. To identify threads that are good candidates to steal work from. We
|
|
// prefer to steal work from threads that are active outside the worker loop.
|
|
// This avoids "snatching" new work away from a thread that has just been
|
|
// given it but not yet noticed.
|
|
//
|
|
// 3. When pushing work to a thread, we use the status read-only to identify
|
|
// when we need to wake the thread. This read-only check avoids the
|
|
// need for mutex / condvar operations in the case where the thread pool
|
|
// remains busy.
|
|
|
|
enum class ThreadStatus : uint8_t {
|
|
Spinning, // Spinning in the work loop, and other cases (initialization) where
|
|
// the thread will soon be in the loop
|
|
Active, // Running user code, not waiting for work
|
|
Blocking, // In the process of blocking; may no longer notice work pushed to it
|
|
Blocked, // Blocked on cv
|
|
Waking, // Not yet back in the worker loop, but wake-up notification sent
|
|
};
|
|
|
|
ThreadStatus GetStatus() const {
|
|
return status;
|
|
}
|
|
|
|
// State transitions, called from other threads
|
|
|
|
void EnsureAwake() {
|
|
ThreadStatus seen = status;
|
|
if (seen == ThreadStatus::Blocking ||
|
|
seen == ThreadStatus::Blocked) {
|
|
std::unique_lock<OrtMutex> lk(mutex);
|
|
// Blocking state exists only transiently during the SetBlock() method
|
|
// while holding the lock. We may observe it at the start of this
|
|
// function, but after acquiring the lock then the target thread
|
|
// will either be blocked or not.
|
|
seen = status;
|
|
assert(seen != ThreadStatus::Blocking);
|
|
if (seen == ThreadStatus::Blocked) {
|
|
status = ThreadStatus::Waking;
|
|
cv.notify_one();
|
|
}
|
|
}
|
|
}
|
|
|
|
// State transitions, called only from the thread itself
|
|
|
|
void SetActive() {
|
|
std::unique_lock<OrtMutex> lk(mutex);
|
|
status = ThreadStatus::Active;
|
|
}
|
|
|
|
void SetSpinning() {
|
|
std::unique_lock<OrtMutex> lk(mutex);
|
|
status = ThreadStatus::Spinning;
|
|
}
|
|
|
|
void SetBlocked(std::function<bool()> should_block,
|
|
std::function<void()> post_block) {
|
|
std::unique_lock<OrtMutex> lk(mutex);
|
|
assert(status == ThreadStatus::Spinning);
|
|
status = ThreadStatus::Blocking;
|
|
if (should_block()) {
|
|
status = ThreadStatus::Blocked;
|
|
while (status == ThreadStatus::Blocked) {
|
|
cv.wait(lk);
|
|
}
|
|
post_block();
|
|
}
|
|
status = ThreadStatus::Spinning;
|
|
}
|
|
|
|
private:
|
|
std::atomic<ThreadStatus> status{ThreadStatus::Spinning};
|
|
OrtMutex mutex;
|
|
OrtCondVar cv;
|
|
};
|
|
|
|
Environment& env_;
|
|
const int num_threads_;
|
|
const bool allow_spinning_;
|
|
Eigen::MaxSizeVector<WorkerData> worker_data_;
|
|
Eigen::MaxSizeVector<Eigen::MaxSizeVector<unsigned>> all_coprimes_;
|
|
std::atomic<unsigned> blocked_; // Count of blocked workers, used as a termination condition
|
|
std::atomic<bool> done_;
|
|
std::atomic<bool> cancelled_;
|
|
|
|
// Allow control over how many bits to use in each entry in good_worker_hints_.
|
|
// We reduce this below the full 64-bit word size for two reasons. First, it
|
|
// helps test coverage on machines without 64 vCPUS. Second, it lets us
|
|
// reduce contention by having different threads start work searching for hints
|
|
// at different locations in the bitmap.
|
|
|
|
static const int bits_per_hint_word_ = 4;
|
|
int num_hint_words_;
|
|
std::unique_ptr<std::atomic<uint64_t>[]> good_worker_hints_;
|
|
|
|
// Wake any blocked workers so that they can cleanly exit WorkerLoop(). For an
|
|
// abrupt exit, cancelled_==true and threads will exit their worker loops. For
|
|
// a clean exit, each thread will observe (1) done_ set, indicating that the
|
|
// destructor has been called, (2) all threads blocked, and (3) no
|
|
// items in the work queues.
|
|
|
|
void WakeAllWorkersForExit() {
|
|
for (auto &td: worker_data_) {
|
|
td.EnsureAwake();
|
|
}
|
|
}
|
|
|
|
// Main worker thread loop.
|
|
void WorkerLoop(int thread_id) {
|
|
PerThread* pt = GetPerThread();
|
|
WorkerData& td = worker_data_[thread_id];
|
|
Queue& q = td.queue;
|
|
bool should_exit = false;
|
|
pt->pool = this;
|
|
pt->rand = GlobalThreadIdHash();
|
|
pt->thread_id = thread_id;
|
|
|
|
assert(td.GetStatus() == WorkerData::ThreadStatus::Spinning);
|
|
SetGoodWorkerHint(thread_id, true /* Is good */);
|
|
|
|
const int log2_spin = 20;
|
|
const int spin_count = allow_spinning_ ? (1ull<<log2_spin) : 0;
|
|
const int steal_count = spin_count/100;
|
|
|
|
while (!cancelled_ && !should_exit) {
|
|
Task t = q.PopFront();
|
|
if (!t.f) {
|
|
// Spin waiting for work. We indicate, via SetGoodWorkerHint, that we are
|
|
// spinning. This will bias other threads toward pushing work to our queue.
|
|
// In addition, priodically make a best-effort attempt to steal from other
|
|
// threads which are not themselves spinning.
|
|
|
|
SetGoodWorkerHint(thread_id, true);
|
|
for (int i = 0; i < spin_count && !t.f && !cancelled_ && !done_; i++) {
|
|
t = (i%steal_count == 0) ? TrySteal() : q.PopFront();
|
|
}
|
|
SetGoodWorkerHint(thread_id, false);
|
|
|
|
if (!t.f) {
|
|
// No work passed to us while spinning; make a further full attempt to
|
|
// steal work from other threads prior to blocking.
|
|
if (num_threads_ != 1) {
|
|
t = Steal(true /* true => check all queues */);
|
|
}
|
|
if (!t.f) {
|
|
td.SetBlocked(
|
|
// Pre-block test
|
|
[&]() -> bool {
|
|
bool should_block = true;
|
|
// We already did a best-effort emptiness check when stealing; now
|
|
// do a full check prior to blocking.
|
|
int victim = NonEmptyQueueIndex();
|
|
if (victim != -1) {
|
|
should_block = false;
|
|
if (!cancelled_) {
|
|
t = worker_data_[victim].queue.PopBack();
|
|
}
|
|
}
|
|
// Number of blocked threads is used as termination condition.
|
|
// If we are shutting down and all worker threads blocked without work,
|
|
// that's we are done.
|
|
if (should_block) {
|
|
blocked_++;
|
|
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
|
|
should_block = false;
|
|
// Almost done, but need to re-check queues.
|
|
// Consider that all queues are empty and all worker threads are preempted
|
|
// right after incrementing blocked_ above. Now a free-standing thread
|
|
// submits work and calls destructor (which sets done_). If we don't
|
|
// re-check queues, we will exit leaving the work unexecuted.
|
|
if (NonEmptyQueueIndex() != -1) {
|
|
// Note: we must not pop from queues before we decrement blocked_,
|
|
// otherwise the following scenario is possible. Consider that instead
|
|
// of checking for emptiness we popped the only element from queues.
|
|
// Now other worker threads can start exiting, which is bad if the
|
|
// work item submits other work. So we just check emptiness here,
|
|
// which ensures that all worker threads exit at the same time.
|
|
blocked_--;
|
|
} else {
|
|
should_exit = true;
|
|
}
|
|
}
|
|
}
|
|
return should_block;
|
|
},
|
|
// Post-block update (executed only if we blocked)
|
|
[&]() {
|
|
blocked_--;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
if (t.f) {
|
|
td.SetActive();
|
|
env_.ExecuteTask(t);
|
|
td.SetSpinning();
|
|
}
|
|
}
|
|
|
|
// Whichever thread(s) observe the termination conditions are responsible for waking
|
|
// any other threads that have remained blocked.
|
|
if (should_exit) {
|
|
WakeAllWorkersForExit();
|
|
}
|
|
}
|
|
|
|
// Steal tries to steal work from other worker threads in a best-effort
|
|
// manner. We make two passes over the threads:
|
|
//
|
|
// - round 0 : we attempt to steal from threads that are running in
|
|
// user code (ThreadStatus::Active). The intuition behind this is that
|
|
// the thread is busy with other work, and that by preferring to
|
|
// steal from busy victims we will avoid "snatching" work from a
|
|
// thread which is just about to notice the work itself.
|
|
//
|
|
// - round 1 : we steal work from any thread, including those which claim
|
|
// to be spinning. In these cases, even though the victim thread is
|
|
// looking for work itself, it may have been pre-empted.
|
|
|
|
// Attempts to take a task from the back of another worker's queue.
//
// check_all: when true, up to two passes are made over all workers (first
//            only those whose status is Active, then every worker). When
//            false, the function gives up right after examining the first
//            candidate of the first pass, so at most one queue is touched and
//            only if that worker is Active.
// Returns the stolen task, or a default-constructed Task if nothing was taken.
// Advances the calling thread's random state by one step.
Task Steal(bool check_all) {
  PerThread* const per_thread = GetPerThread();
  const unsigned num_workers = static_cast<unsigned>(num_threads_);
  const unsigned seed = Rand(&per_thread->rand);

  // Pick the walk stride from the table entry matching the worker count.
  const auto& strides = all_coprimes_[num_workers - 1];
  const unsigned stride = strides[seed % strides.size()];
  const unsigned first_victim = seed % num_workers;

  for (int pass = 0; pass < 2; ++pass) {
    const bool accept_any_status = (pass == 1);
    unsigned victim = first_victim;

    for (unsigned visited = 0; visited < num_workers; ++visited) {
      assert(victim < num_workers);
      WorkerData& candidate = worker_data_[victim];

      if (accept_any_status ||
          candidate.GetStatus() == WorkerData::ThreadStatus::Active) {
        Task stolen = candidate.queue.PopBack();
        if (stolen.f) {
          return stolen;
        }
      }

      // Single-shot mode: stop after the very first candidate.
      if (!check_all) {
        return Task();
      }

      // Step to the next candidate, wrapping around the worker range.
      const unsigned advanced = victim + stride;
      victim = (advanced >= num_workers) ? (advanced - num_workers) : advanced;
    }
  }

  return Task();
}
|
|
|
|
// Cheap, single-candidate steal attempt: delegates to Steal() with the
// "check all queues" option switched off.
Task TrySteal() {
  const bool check_all_queues = false;
  return Steal(check_all_queues);
}
|
|
|
|
int NonEmptyQueueIndex() {
|
|
PerThread* pt = GetPerThread();
|
|
const unsigned size = static_cast<unsigned>(worker_data_.size());
|
|
unsigned r = Rand(&pt->rand);
|
|
unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
|
|
unsigned victim = r % size;
|
|
for (unsigned i = 0; i < size; i++) {
|
|
if (!worker_data_[victim].queue.Empty()) {
|
|
return victim;
|
|
}
|
|
victim += inc;
|
|
if (victim >= size) {
|
|
victim -= size;
|
|
}
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
// Hash of the calling thread's std::thread::id, widened to 64 bits.
// WorkerLoop() uses it to seed the per-thread random state.
static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
  const std::thread::id self = std::this_thread::get_id();
  const std::hash<std::thread::id> hasher{};
  return hasher(self);
}
|
|
|
|
// Returns the calling thread's PerThread record. The record is a
// function-local thread_local object, so every thread gets its own instance
// and the returned pointer must only be used by the thread that obtained it.
EIGEN_STRONG_INLINE PerThread* GetPerThread() {
  static thread_local PerThread tls_record;
  return &tls_record;
}
|
|
|
|
// One step of a PCG-style generator (PCG-XSH-RS output scheme).
//
// state: in/out 64-bit generator state; it is advanced with a linear
//        congruential update on every call.
// Returns a value derived from the state as it was *before* the update.
static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
  const uint64_t previous = *state;

  // Advance the underlying LCG.
  *state = previous * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;

  // Output permutation: xorshift the old state, then apply a shift whose
  // amount is chosen by the old state's top three bits.
  const uint64_t xorshifted = previous ^ (previous >> 22);
  const uint64_t shift_amount = 22 + (previous >> 61);
  return static_cast<unsigned>(xorshifted >> shift_amount);
}
|
|
};
|
|
|
|
} // namespace onnxruntime
|