Honor allow_spinning at barrier at end of parallel sections (#4767)

This commit means that when the thread pool is configured to spin, then we spin at the barrier at the end of parallel sections in the main thread, in addition to having workers spin waiting for work. 

The change updates Barrier.h to take an additional boolean to select spin/block, and passes this in based on the thread pool configuration. 

It adds an additional test case for barriers, although no problems were identified by the test case.
This commit is contained in:
Tim Harris 2020-08-13 09:40:40 +01:00 committed by GitHub
parent 61b2a663a3
commit 9cec98ec1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 8 deletions

View file

@ -7,14 +7,18 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <assert.h>
#include "core/platform/ort_mutex.h"
#include <mutex>
#include <atomic>
namespace onnxruntime {
class Barrier {
public:
explicit Barrier(unsigned int count) : state_(count << 1), notified_(false) {
explicit Barrier(unsigned int count, bool spin = false)
: state_(count << 1), notified_(false), spin_(spin) {
assert(((count << 1) >> 1) == count);
}
#ifdef NDEBUG
@ -42,12 +46,18 @@ class Barrier {
}
void Wait() {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
if ((v >> 1) == 0)
return;
std::unique_lock<OrtMutex> l(mu_);
while (!notified_) {
cv_.wait(l);
if (spin_) {
while ((state_ >> 1) != 0) {
/* spin */
}
} else {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
if ((v >> 1) == 0)
return;
std::unique_lock<OrtMutex> l(mu_);
while (!notified_) {
cv_.wait(l);
}
}
}
@ -56,6 +66,7 @@ class Barrier {
OrtCondVar cv_;
std::atomic<unsigned int> state_; // low bit is waiter flag
bool notified_;
const bool spin_;
};
// Notification is an object that allows a user to to wait for another

View file

@ -557,7 +557,7 @@ void RunInParallel(std::function<void()> fn, unsigned n) override {
// item. This lets us remove any work items that do not get executed by the threads
// that we push them to.
std::vector<std::pair<int, unsigned>> pending_items;
Barrier b(n);
Barrier b(n, allow_spinning_);
my_pt->in_parallel = true;
if (!my_pt->tag.Get()) {

View file

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/platform/Barrier.h"
#include "core/platform/threadpool.h"
#include <core/common/make_unique.h>
#include "gtest/gtest.h"
#include <atomic>
#include <thread>
namespace {
static void TestBarrier(int num_threads, uint64_t per_thread_count, bool spin) {
std::atomic<uint64_t> counter{0};
onnxruntime::Barrier barrier(num_threads, spin);
std::vector<std::thread> threads;
for (auto i = 0; i < num_threads + 1; i++) {
threads.push_back(std::thread([&, i] {
if (i > 0) {
// Worker thread; increment the shared counter then
// notify the barrier.
for (uint64_t j = 0; j < per_thread_count; j++) {
counter++;
}
barrier.Notify();
} else {
// Main thread; wait on the barrier, and then check the count seen.
barrier.Wait();
ASSERT_EQ(counter, per_thread_count * num_threads);
}
}));
}
// Wait for the threads to finish
for (auto &t : threads) {
t.join();
}
}
} // namespace
namespace onnxruntime {
constexpr uint64_t count = 1000000ull;
TEST(BarrierTest, TestBarrier_0Workers_Spin) {
TestBarrier(0, count, true);
}
TEST(BarrierTest, TestBarrier_0Workers_Block) {
TestBarrier(0, count, false);
}
TEST(BarrierTest, TestBarrier_1Worker_Spin) {
TestBarrier(1, count, true);
}
TEST(BarrierTest, TestBarrier_1Worker_Block) {
TestBarrier(1, count, false);
}
TEST(BarrierTest, TestBarrier_4Workers_Spin) {
TestBarrier(4, count, true);
}
TEST(BarrierTest, TestBarrier_4Workers_Block) {
TestBarrier(4, count, false);
}
} // namespace onnxruntime