From 73bc09421c1bdafda21402ab168ec397ec9e8eb0 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 24 Apr 2019 15:55:04 -0700 Subject: [PATCH] Fix deadlock in parallel executor (#891) Fix deadlock in parallel executor Execute immediately if ParalellFor has only 1 task --- onnxruntime/core/common/threadpool.cc | 44 ++++++++++++++++++--------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 7f07b0d3fa..875b90c4ff 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -4,6 +4,8 @@ #include "core/platform/threadpool.h" #include "core/common/common.h" +#include + #ifdef USE_EIGEN_THREADPOOL #if defined(_MSC_VER) #pragma warning(disable : 4267) @@ -24,7 +26,6 @@ namespace onnxruntime { namespace concurrency { -#ifdef USE_EIGEN_THREADPOOL // TODO: This is temporarily taken from Eigen until we upgrade its version. // Barrier is an object that allows one or more threads to wait until @@ -32,20 +33,20 @@ namespace concurrency { class Barrier { public: Barrier(unsigned int count) : state_(count << 1), notified_(false) { - eigen_assert(((count << 1) >> 1) == count); + assert(((count << 1) >> 1) == count); } ~Barrier() { - eigen_assert((state_ >> 1) == 0); + assert((state_ >> 1) == 0); } void Notify() { unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; if (v != 1) { - eigen_assert(((v + 2) & ~1) != 0); + assert(((v + 2) & ~1) != 0); return; // either count has not dropped to 0, or waiter is not waiting } std::unique_lock l(mu_); - eigen_assert(!notified_); + assert(!notified_); notified_ = true; cv_.notify_all(); } @@ -66,6 +67,7 @@ class Barrier { bool notified_; }; +#ifdef USE_EIGEN_THREADPOOL class ThreadPool::Impl : public Eigen::ThreadPool { public: Impl(const std::string& name, int num_threads) @@ -77,8 +79,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool { // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism // We will simply rely on the work queue and stealing in the short term. Barrier barrier(static_cast(total)); - std::function handle_iteration; - handle_iteration = [=, &handle_iteration, &barrier, &fn](int iteration) { + std::function handle_iteration = [&barrier, &fn](int iteration) { fn(iteration); barrier.Notify(); }; @@ -94,8 +95,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool { // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism // We will simply rely on the work queue and stealing in the short term. Barrier barrier(static_cast(last - first + 1)); - std::function handle_range; - handle_range = [=, &handle_range, &barrier, &fn](int64_t first, int64_t last) { + std::function handle_range = [&barrier, &fn](int64_t first, int64_t last) { fn(first, last); barrier.Notify(); }; @@ -127,11 +127,16 @@ class ThreadPool::Impl : public TaskThreadPool { fn(id); } #else + Barrier barrier(static_cast(total)); + std::function handle_iteration = [&barrier, &fn](int iteration) { + fn(iteration); + barrier.Notify(); + }; for (int32_t id = 0; id < total; ++id) { - std::packaged_task task(std::bind(fn, id)); + std::packaged_task task(std::bind(handle_iteration, id)); RunTask(std::move(task)); } - WaitWorkComplete(); + barrier.Wait(); #endif } @@ -142,11 +147,16 @@ class ThreadPool::Impl : public TaskThreadPool { fn(id, id + 1); } #else + Barrier barrier(static_cast(last - first + 1)); + std::function handle_iteration = [&barrier, &fn](int64_t first, int64_t last) { + fn(first, last); + barrier.Notify(); + }; for (int64_t id = first; id < last; ++id) { - std::packaged_task task(std::bind(fn, id, id + 1)); + std::packaged_task task(std::bind(handle_iteration, id, id + 1)); RunTask(std::move(task)); } - WaitWorkComplete(); + barrier.Wait(); #endif } }; @@ -162,7 +172,10 @@ ThreadPool::ThreadPool(const std::string& name, int num_threads) void ThreadPool::Schedule(std::function fn) { impl_->Schedule(fn); } void ThreadPool::ParallelFor(int32_t total, std::function fn) { - if (total <= 0) { + if (total <= 0) return; + + if (total == 1) { + fn(0); return; } @@ -170,7 +183,8 @@ void ThreadPool::ParallelFor(int32_t total, std::function fn) { } void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function fn) { - if (last <= first) { + if (last <= first) return; + if (last - first == 1) { fn(first, last); return; }