Eliminate unnecessary status lock acquisition in TP (#12196)

Eliminate unnecessary status lock acquisition in the Thread Pool
This commit is contained in:
Dmitri Smirnov 2022-07-19 14:16:12 -07:00 committed by GitHub
parent 972e5e7300
commit 4f106d2b3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1382,8 +1382,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
// State transitions, called from other threads
// We employ mutex for synchronizing on Blocked/Waking state (EnsureAwake/SeBlocked)
// to wakeup the thread in the event it goes to sleep. Because thread status
// is an atomic member the lock is not necessary to update it.
// Thus, we do not obtain the mutex when we set Active/Spinning state for the thread.
// While manipulating under the mutex, we employ relaxed semantics so the compiler is not restricted
// any further.
void EnsureAwake() {
ThreadStatus seen = status;
ThreadStatus seen = GetStatus();
if (seen == ThreadStatus::Blocking ||
seen == ThreadStatus::Blocked) {
std::unique_lock<OrtMutex> lk(mutex);
@ -1391,10 +1397,10 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
// while holding the lock. We may observe it at the start of this
// function, but after acquiring the lock then the target thread
// will either be blocked or not.
seen = status;
seen = status.load(std::memory_order_relaxed);
assert(seen != ThreadStatus::Blocking);
if (seen == ThreadStatus::Blocked) {
status = ThreadStatus::Waking;
status.store(ThreadStatus::Waking, std::memory_order_relaxed);
lk.unlock();
cv.notify_one();
}
@ -1402,30 +1408,31 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
}
// State transitions, called only from the thread itself
// The lock is only used in the synchronization between EnsureAwake and SetBlocked,
// while the Active vs Spinning states are just used as a hint for work stealing
// (prefer to steal from a thread that is actively running a task, rather than stealing from
// a thread that is spinning and likely to pick up the task itself).
void SetActive() {
std::lock_guard<OrtMutex> lk(mutex);
status = ThreadStatus::Active;
}
void SetSpinning() {
std::lock_guard<OrtMutex> lk(mutex);
status = ThreadStatus::Spinning;
}
void SetBlocked(std::function<bool()> should_block,
std::function<void()> post_block) {
std::unique_lock<OrtMutex> lk(mutex);
assert(status == ThreadStatus::Spinning);
status = ThreadStatus::Blocking;
assert(GetStatus() == ThreadStatus::Spinning);
status.store(ThreadStatus::Blocking, std::memory_order_relaxed);
if (should_block()) {
status = ThreadStatus::Blocked;
while (status == ThreadStatus::Blocked) {
status.store(ThreadStatus::Blocked, std::memory_order_relaxed);
do {
cv.wait(lk);
}
} while (status.load(std::memory_order_relaxed) == ThreadStatus::Blocked);
post_block();
}
status = ThreadStatus::Spinning;
status.store(ThreadStatus::Spinning, std::memory_order_relaxed);
}
private: