diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc
index 9292b4a8c7..6cdcb3add7 100644
--- a/onnxruntime/core/common/threadpool.cc
+++ b/onnxruntime/core/common/threadpool.cc
@@ -39,23 +39,9 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
     fn(0);
     return;
   }
+
   // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
   // We will simply rely on the work queue and stealing in the short term.
-  if (total > NumThreads()) {
-    //The dispatcher thread will be idle at here
-    Barrier barrier(static_cast<unsigned int>(total));
-    std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
-      fn(iteration);
-      barrier.Notify();
-    };
-
-    for (int32_t id = 0; id < total; ++id) {
-      Schedule([=, &handle_iteration]() { handle_iteration(id); });
-    }
-
-    barrier.Wait();
-    return;
-  }
   Barrier barrier(static_cast<unsigned int>(total - 1));
   std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
     fn(iteration);
@@ -65,7 +51,7 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
   for (int32_t id = 1; id < total; ++id) {
     Schedule([=, &handle_iteration]() { handle_iteration(id); });
   }
-  //reuse the current thread for one task
+
   fn(0);
   barrier.Wait();
 }
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index c19ac16694..6ba13c728d 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -545,7 +545,7 @@ MlasGetMaximumThreadCount(
     MLAS_UNREFERENCED_PARAMETER(ThreadPool);
 #else
     if (ThreadPool != nullptr) {
-        return ThreadPool->NumThreads();
+        return ThreadPool->NumThreads() + 1;
     }
 #endif
 