mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Fix deadlock in parallel executor (#891)
Fix deadlock in parallel executor Execute immediately if ParalellFor has only 1 task
This commit is contained in:
parent
ba3b82648e
commit
73bc09421c
1 changed files with 29 additions and 15 deletions
|
|
@ -4,6 +4,8 @@
|
|||
#include "core/platform/threadpool.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#ifdef USE_EIGEN_THREADPOOL
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable : 4267)
|
||||
|
|
@ -24,7 +26,6 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
namespace concurrency {
|
||||
#ifdef USE_EIGEN_THREADPOOL
|
||||
|
||||
// TODO: This is temporarily taken from Eigen until we upgrade its version.
|
||||
// Barrier is an object that allows one or more threads to wait until
|
||||
|
|
@ -32,20 +33,20 @@ namespace concurrency {
|
|||
class Barrier {
|
||||
public:
|
||||
Barrier(unsigned int count) : state_(count << 1), notified_(false) {
|
||||
eigen_assert(((count << 1) >> 1) == count);
|
||||
assert(((count << 1) >> 1) == count);
|
||||
}
|
||||
~Barrier() {
|
||||
eigen_assert((state_ >> 1) == 0);
|
||||
assert((state_ >> 1) == 0);
|
||||
}
|
||||
|
||||
void Notify() {
|
||||
unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
|
||||
if (v != 1) {
|
||||
eigen_assert(((v + 2) & ~1) != 0);
|
||||
assert(((v + 2) & ~1) != 0);
|
||||
return; // either count has not dropped to 0, or waiter is not waiting
|
||||
}
|
||||
std::unique_lock<std::mutex> l(mu_);
|
||||
eigen_assert(!notified_);
|
||||
assert(!notified_);
|
||||
notified_ = true;
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
|
@ -66,6 +67,7 @@ class Barrier {
|
|||
bool notified_;
|
||||
};
|
||||
|
||||
#ifdef USE_EIGEN_THREADPOOL
|
||||
class ThreadPool::Impl : public Eigen::ThreadPool {
|
||||
public:
|
||||
Impl(const std::string& name, int num_threads)
|
||||
|
|
@ -77,8 +79,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool {
|
|||
// TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
|
||||
// We will simply rely on the work queue and stealing in the short term.
|
||||
Barrier barrier(static_cast<unsigned int>(total));
|
||||
std::function<void(int32_t)> handle_iteration;
|
||||
handle_iteration = [=, &handle_iteration, &barrier, &fn](int iteration) {
|
||||
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
|
||||
fn(iteration);
|
||||
barrier.Notify();
|
||||
};
|
||||
|
|
@ -94,8 +95,7 @@ class ThreadPool::Impl : public Eigen::ThreadPool {
|
|||
// TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
|
||||
// We will simply rely on the work queue and stealing in the short term.
|
||||
Barrier barrier(static_cast<unsigned int>(last - first + 1));
|
||||
std::function<void(int64_t, int64_t)> handle_range;
|
||||
handle_range = [=, &handle_range, &barrier, &fn](int64_t first, int64_t last) {
|
||||
std::function<void(int64_t, int64_t)> handle_range = [&barrier, &fn](int64_t first, int64_t last) {
|
||||
fn(first, last);
|
||||
barrier.Notify();
|
||||
};
|
||||
|
|
@ -127,11 +127,16 @@ class ThreadPool::Impl : public TaskThreadPool {
|
|||
fn(id);
|
||||
}
|
||||
#else
|
||||
Barrier barrier(static_cast<unsigned int>(total));
|
||||
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
|
||||
fn(iteration);
|
||||
barrier.Notify();
|
||||
};
|
||||
for (int32_t id = 0; id < total; ++id) {
|
||||
std::packaged_task<void()> task(std::bind(fn, id));
|
||||
std::packaged_task<void()> task(std::bind(handle_iteration, id));
|
||||
RunTask(std::move(task));
|
||||
}
|
||||
WaitWorkComplete();
|
||||
barrier.Wait();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -142,11 +147,16 @@ class ThreadPool::Impl : public TaskThreadPool {
|
|||
fn(id, id + 1);
|
||||
}
|
||||
#else
|
||||
Barrier barrier(static_cast<unsigned int>(last - first + 1));
|
||||
std::function<void(int64_t, int64_t)> handle_iteration = [&barrier, &fn](int64_t first, int64_t last) {
|
||||
fn(first, last);
|
||||
barrier.Notify();
|
||||
};
|
||||
for (int64_t id = first; id < last; ++id) {
|
||||
std::packaged_task<void()> task(std::bind(fn, id, id + 1));
|
||||
std::packaged_task<void()> task(std::bind(handle_iteration, id, id + 1));
|
||||
RunTask(std::move(task));
|
||||
}
|
||||
WaitWorkComplete();
|
||||
barrier.Wait();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
|
@ -162,7 +172,10 @@ ThreadPool::ThreadPool(const std::string& name, int num_threads)
|
|||
void ThreadPool::Schedule(std::function<void()> fn) { impl_->Schedule(fn); }
|
||||
|
||||
void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
|
||||
if (total <= 0) {
|
||||
if (total <= 0) return;
|
||||
|
||||
if (total == 1) {
|
||||
fn(0);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -170,7 +183,8 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
|
|||
}
|
||||
|
||||
void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
|
||||
if (last <= first) {
|
||||
if (last <= first) return;
|
||||
if (last - first == 1) {
|
||||
fn(first, last);
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue