Fix deadlock in parallel executor (#891)

Fix deadlock in parallel executor
  Execute immediately if ParalellFor has only 1 task
This commit is contained in:
Yufeng Li 2019-04-24 15:55:04 -07:00 committed by Dmitri Smirnov
parent ba3b82648e
commit 73bc09421c

View file

@ -4,6 +4,8 @@
#include "core/platform/threadpool.h"
#include "core/common/common.h"
#include <cassert>
#ifdef USE_EIGEN_THREADPOOL
#if defined(_MSC_VER)
#pragma warning(disable : 4267)
@ -24,7 +26,6 @@
namespace onnxruntime {
namespace concurrency {
#ifdef USE_EIGEN_THREADPOOL
// TODO: This is temporarily taken from Eigen until we upgrade its version.
// Barrier is an object that allows one or more threads to wait until
@ -32,20 +33,20 @@ namespace concurrency {
class Barrier {
public:
Barrier(unsigned int count) : state_(count << 1), notified_(false) {
eigen_assert(((count << 1) >> 1) == count);
assert(((count << 1) >> 1) == count);
}
~Barrier() {
eigen_assert((state_ >> 1) == 0);
assert((state_ >> 1) == 0);
}
void Notify() {
unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
if (v != 1) {
eigen_assert(((v + 2) & ~1) != 0);
assert(((v + 2) & ~1) != 0);
return; // either count has not dropped to 0, or waiter is not waiting
}
std::unique_lock<std::mutex> l(mu_);
eigen_assert(!notified_);
assert(!notified_);
notified_ = true;
cv_.notify_all();
}
@ -66,6 +67,7 @@ class Barrier {
bool notified_;
};
#ifdef USE_EIGEN_THREADPOOL
class ThreadPool::Impl : public Eigen::ThreadPool {
public:
Impl(const std::string& name, int num_threads)
@ -77,8 +79,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool {
// TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
// We will simply rely on the work queue and stealing in the short term.
Barrier barrier(static_cast<unsigned int>(total));
std::function<void(int32_t)> handle_iteration;
handle_iteration = [=, &handle_iteration, &barrier, &fn](int iteration) {
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
fn(iteration);
barrier.Notify();
};
@ -94,8 +95,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool {
// TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
// We will simply rely on the work queue and stealing in the short term.
Barrier barrier(static_cast<unsigned int>(last - first + 1));
std::function<void(int64_t, int64_t)> handle_range;
handle_range = [=, &handle_range, &barrier, &fn](int64_t first, int64_t last) {
std::function<void(int64_t, int64_t)> handle_range = [&barrier, &fn](int64_t first, int64_t last) {
fn(first, last);
barrier.Notify();
};
@ -127,11 +127,16 @@ class ThreadPool::Impl : public TaskThreadPool {
fn(id);
}
#else
Barrier barrier(static_cast<unsigned int>(total));
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
fn(iteration);
barrier.Notify();
};
for (int32_t id = 0; id < total; ++id) {
std::packaged_task<void()> task(std::bind(fn, id));
std::packaged_task<void()> task(std::bind(handle_iteration, id));
RunTask(std::move(task));
}
WaitWorkComplete();
barrier.Wait();
#endif
}
@ -142,11 +147,16 @@ class ThreadPool::Impl : public TaskThreadPool {
fn(id, id + 1);
}
#else
Barrier barrier(static_cast<unsigned int>(last - first + 1));
std::function<void(int64_t, int64_t)> handle_iteration = [&barrier, &fn](int64_t first, int64_t last) {
fn(first, last);
barrier.Notify();
};
for (int64_t id = first; id < last; ++id) {
std::packaged_task<void()> task(std::bind(fn, id, id + 1));
std::packaged_task<void()> task(std::bind(handle_iteration, id, id + 1));
RunTask(std::move(task));
}
WaitWorkComplete();
barrier.Wait();
#endif
}
};
@ -162,7 +172,10 @@ ThreadPool::ThreadPool(const std::string& name, int num_threads)
void ThreadPool::Schedule(std::function<void()> fn) { impl_->Schedule(fn); }
void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
if (total <= 0) {
if (total <= 0) return;
if (total == 1) {
fn(0);
return;
}
@ -170,7 +183,8 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
}
void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
if (last <= first) {
if (last <= first) return;
if (last - first == 1) {
fn(first, last);
return;
}