From 2ff66b80e0e075696e34c78ab59b351bc8590d56 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 16 Dec 2024 09:05:12 -0800 Subject: [PATCH] Fix a deadlock bug in EigenNonBlockingThreadPool.h (#23098) ### Description This PR fixes a deadlock bug in EigenNonBlockingThreadPool.h. It only happens on platforms with weakly ordered memory model, such as ARM64. --- .../platform/EigenNonBlockingThreadPool.h | 124 ++++++++++-------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 27b14f008a..a7c63c507d 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -1467,11 +1467,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter status = ThreadStatus::Spinning; } - void SetBlocked(std::function should_block, + bool SetBlocked(std::function should_block, std::function post_block) { std::unique_lock lk(mutex); - assert(GetStatus() == ThreadStatus::Spinning); - status.store(ThreadStatus::Blocking, std::memory_order_relaxed); + auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst); + if (old_status != ThreadStatus::Spinning) { + // Encountered a logical error + return false; + } if (should_block()) { status.store(ThreadStatus::Blocked, std::memory_order_relaxed); do { @@ -1480,6 +1483,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter post_block(); } status.store(ThreadStatus::Spinning, std::memory_order_relaxed); + return true; } private: @@ -1558,62 +1562,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Attempt to block if (!t) { - td.SetBlocked( // Pre-block test - [&]() -> bool { - bool should_block = true; - // Check whether work was pushed to us while attempting to block. We make - // this test while holding the per-thread status lock, and after setting - // our status to ThreadStatus::Blocking. - // - // This synchronizes with ThreadPool::Schedule which pushes work to the queue - // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): - // - // Main thread: Worker: - // #1 Push work #A Set status blocking - // #2 Read worker status #B Check queue - // #3 Wake if blocking/blocked - // - // If #A is before #2 then main sees worker blocked and wakes - // - // If #A if after #2 then #B will see #1, and we abandon blocking - assert(!t); - t = q.PopFront(); - if (t) { - should_block = false; - } - - // No work pushed to us, continue attempting to block. The remaining - // test is to synchronize with termination requests. If we are - // shutting down and all worker threads blocked without work, that's - // we are done. - if (should_block) { - blocked_++; - if (done_ && blocked_ == num_threads_) { - should_block = false; - // Almost done, but need to re-check queues. - // Consider that all queues are empty and all worker threads are preempted - // right after incrementing blocked_ above. Now a free-standing thread - // submits work and calls destructor (which sets done_). If we don't - // re-check queues, we will exit leaving the work unexecuted. - if (NonEmptyQueueIndex() != -1) { - // Note: we must not pop from queues before we decrement blocked_, - // otherwise the following scenario is possible. Consider that instead - // of checking for emptiness we popped the only element from queues. - // Now other worker threads can start exiting, which is bad if the - // work item submits other work. So we just check emptiness here, - // which ensures that all worker threads exit at the same time. - blocked_--; - } else { - should_exit = true; + if (!td.SetBlocked( // Pre-block test + [&]() -> bool { + bool should_block = true; + // Check whether work was pushed to us while attempting to block. We make + // this test while holding the per-thread status lock, and after setting + // our status to ThreadStatus::Blocking. + // + // This synchronizes with ThreadPool::Schedule which pushes work to the queue + // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): + // + // Main thread: Worker: + // #1 Push work #A Set status blocking + // #2 Read worker status #B Check queue + // #3 Wake if blocking/blocked + // + // If #A is before #2 then main sees worker blocked and wakes + // + // If #A if after #2 then #B will see #1, and we abandon blocking + assert(!t); + t = q.PopFront(); + if (t) { + should_block = false; } - } - } - return should_block; - }, - // Post-block update (executed only if we blocked) - [&]() { - blocked_--; - }); + + // No work pushed to us, continue attempting to block. The remaining + // test is to synchronize with termination requests. If we are + // shutting down and all worker threads blocked without work, that's + // we are done. + if (should_block) { + blocked_++; + if (done_ && blocked_ == num_threads_) { + should_block = false; + // Almost done, but need to re-check queues. + // Consider that all queues are empty and all worker threads are preempted + // right after incrementing blocked_ above. Now a free-standing thread + // submits work and calls destructor (which sets done_). If we don't + // re-check queues, we will exit leaving the work unexecuted. + if (NonEmptyQueueIndex() != -1) { + // Note: we must not pop from queues before we decrement blocked_, + // otherwise the following scenario is possible. Consider that instead + // of checking for emptiness we popped the only element from queues. + // Now other worker threads can start exiting, which is bad if the + // work item submits other work. So we just check emptiness here, + // which ensures that all worker threads exit at the same time. + blocked_--; + } else { + should_exit = true; + } + } + } + return should_block; + }, + // Post-block update (executed only if we blocked) + [&]() { + blocked_--; + })) { + // Encountered a fatal logic error in SetBlocked + should_exit = true; + break; + } // Thread just unblocked. Unless we picked up work while // blocking, or are exiting, then either work was pushed to // us, or it was pushed to an overloaded queue