mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Honor allow_spinning at barrier at end of parallel sections (#4767)
This commit means that when the thread pool is configured to spin, then we spin at the barrier at the end of parallel sections in the main thread, in addition to having workers spin waiting for work. The change updates Barrier.h to take an additional boolean to select spin/block, and passes this in based on the thread pool configuration. It adds an additional test case for barriers, although no problems were identified by the test case.
This commit is contained in:
parent
61b2a663a3
commit
9cec98ec1b
3 changed files with 93 additions and 8 deletions
|
|
@ -7,14 +7,18 @@
|
|||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "core/platform/ort_mutex.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
|
||||
namespace onnxruntime {
|
||||
class Barrier {
|
||||
public:
|
||||
explicit Barrier(unsigned int count) : state_(count << 1), notified_(false) {
|
||||
explicit Barrier(unsigned int count, bool spin = false)
|
||||
: state_(count << 1), notified_(false), spin_(spin) {
|
||||
assert(((count << 1) >> 1) == count);
|
||||
}
|
||||
#ifdef NDEBUG
|
||||
|
|
@ -42,12 +46,18 @@ class Barrier {
|
|||
}
|
||||
|
||||
void Wait() {
|
||||
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
|
||||
if ((v >> 1) == 0)
|
||||
return;
|
||||
std::unique_lock<OrtMutex> l(mu_);
|
||||
while (!notified_) {
|
||||
cv_.wait(l);
|
||||
if (spin_) {
|
||||
while ((state_ >> 1) != 0) {
|
||||
/* spin */
|
||||
}
|
||||
} else {
|
||||
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
|
||||
if ((v >> 1) == 0)
|
||||
return;
|
||||
std::unique_lock<OrtMutex> l(mu_);
|
||||
while (!notified_) {
|
||||
cv_.wait(l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -56,6 +66,7 @@ class Barrier {
|
|||
OrtCondVar cv_;
|
||||
std::atomic<unsigned int> state_; // low bit is waiter flag
|
||||
bool notified_;
|
||||
const bool spin_;
|
||||
};
|
||||
|
||||
// Notification is an object that allows a user to to wait for another
|
||||
|
|
|
|||
|
|
@ -557,7 +557,7 @@ void RunInParallel(std::function<void()> fn, unsigned n) override {
|
|||
// item. This lets us remove any work items that do not get executed by the threads
|
||||
// that we push them to.
|
||||
std::vector<std::pair<int, unsigned>> pending_items;
|
||||
Barrier b(n);
|
||||
Barrier b(n, allow_spinning_);
|
||||
|
||||
my_pt->in_parallel = true;
|
||||
if (!my_pt->tag.Get()) {
|
||||
|
|
|
|||
74
onnxruntime/test/platform/barrier_test.cc
Normal file
74
onnxruntime/test/platform/barrier_test.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/platform/Barrier.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
#include <core/common/make_unique.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
namespace {
|
||||
|
||||
static void TestBarrier(int num_threads, uint64_t per_thread_count, bool spin) {
|
||||
std::atomic<uint64_t> counter{0};
|
||||
onnxruntime::Barrier barrier(num_threads, spin);
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
for (auto i = 0; i < num_threads + 1; i++) {
|
||||
threads.push_back(std::thread([&, i] {
|
||||
if (i > 0) {
|
||||
// Worker thread; increment the shared counter then
|
||||
// notify the barrier.
|
||||
for (uint64_t j = 0; j < per_thread_count; j++) {
|
||||
counter++;
|
||||
}
|
||||
barrier.Notify();
|
||||
} else {
|
||||
// Main thread; wait on the barrier, and then check the count seen.
|
||||
barrier.Wait();
|
||||
ASSERT_EQ(counter, per_thread_count * num_threads);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait for the threads to finish
|
||||
for (auto &t : threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr uint64_t count = 1000000ull;
|
||||
|
||||
TEST(BarrierTest, TestBarrier_0Workers_Spin) {
|
||||
TestBarrier(0, count, true);
|
||||
}
|
||||
|
||||
TEST(BarrierTest, TestBarrier_0Workers_Block) {
|
||||
TestBarrier(0, count, false);
|
||||
}
|
||||
|
||||
TEST(BarrierTest, TestBarrier_1Worker_Spin) {
|
||||
TestBarrier(1, count, true);
|
||||
}
|
||||
|
||||
TEST(BarrierTest, TestBarrier_1Worker_Block) {
|
||||
TestBarrier(1, count, false);
|
||||
}
|
||||
|
||||
TEST(BarrierTest, TestBarrier_4Workers_Spin) {
|
||||
TestBarrier(4, count, true);
|
||||
}
|
||||
|
||||
TEST(BarrierTest, TestBarrier_4Workers_Block) {
|
||||
TestBarrier(4, count, false);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue