From 9cec98ec1bcb6b8a668733fcd00d81af61b44d93 Mon Sep 17 00:00:00 2001 From: Tim Harris Date: Thu, 13 Aug 2020 09:40:40 +0100 Subject: [PATCH] Honor allow_spinning at barrier at end of parallel sections (#4767) This commit means that when the thread pool is configured to spin, then we spin at the barrier at the end of parallel sections in the main thread, in addition to having workers spin waiting for work. The change updates Barrier.h to take an additional boolean to select spin/block, and passes this in based on the thread pool configuration. It adds an additional test case for barriers, although no problems were identified by the test case. --- include/onnxruntime/core/platform/Barrier.h | 25 +++++-- .../platform/EigenNonBlockingThreadPool.h | 2 +- onnxruntime/test/platform/barrier_test.cc | 74 +++++++++++++++++++ 3 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/test/platform/barrier_test.cc diff --git a/include/onnxruntime/core/platform/Barrier.h b/include/onnxruntime/core/platform/Barrier.h index 67cc29e3a3..f851cae40e 100644 --- a/include/onnxruntime/core/platform/Barrier.h +++ b/include/onnxruntime/core/platform/Barrier.h @@ -7,14 +7,18 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+#include + #include "core/platform/ort_mutex.h" + #include #include namespace onnxruntime { class Barrier { public: - explicit Barrier(unsigned int count) : state_(count << 1), notified_(false) { + explicit Barrier(unsigned int count, bool spin = false) + : state_(count << 1), notified_(false), spin_(spin) { assert(((count << 1) >> 1) == count); } #ifdef NDEBUG @@ -42,12 +46,18 @@ class Barrier { } void Wait() { - unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); - if ((v >> 1) == 0) - return; - std::unique_lock l(mu_); - while (!notified_) { - cv_.wait(l); + if (spin_) { + while ((state_ >> 1) != 0) { + /* spin */ + } + } else { + unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); + if ((v >> 1) == 0) + return; + std::unique_lock l(mu_); + while (!notified_) { + cv_.wait(l); + } } } @@ -56,6 +66,7 @@ class Barrier { OrtCondVar cv_; std::atomic state_; // low bit is waiter flag bool notified_; + const bool spin_; }; // Notification is an object that allows a user to to wait for another diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 04a47ca5d2..a86bdaf57d 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -557,7 +557,7 @@ void RunInParallel(std::function fn, unsigned n) override { // item. This lets us remove any work items that do not get executed by the threads // that we push them to. std::vector> pending_items; - Barrier b(n); + Barrier b(n, allow_spinning_); my_pt->in_parallel = true; if (!my_pt->tag.Get()) { diff --git a/onnxruntime/test/platform/barrier_test.cc b/onnxruntime/test/platform/barrier_test.cc new file mode 100644 index 0000000000..860714e569 --- /dev/null +++ b/onnxruntime/test/platform/barrier_test.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "core/platform/Barrier.h"
+#include "core/platform/threadpool.h"
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+
+#include <thread>
+#include <vector>
+
+namespace {
+
+static void TestBarrier(int num_threads, uint64_t per_thread_count, bool spin) {
+  std::atomic<uint64_t> counter{0};
+  onnxruntime::Barrier barrier(num_threads, spin);
+
+  std::vector<std::thread> threads;
+  for (auto i = 0; i < num_threads + 1; i++) {
+    threads.push_back(std::thread([&, i] {
+      if (i > 0) {
+        // Worker thread; increment the shared counter then
+        // notify the barrier.
+        for (uint64_t j = 0; j < per_thread_count; j++) {
+          counter++;
+        }
+        barrier.Notify();
+      } else {
+        // Main thread; wait on the barrier, and then check the count seen.
+        barrier.Wait();
+        ASSERT_EQ(counter, per_thread_count * num_threads);
+      }
+    }));
+  }
+
+  // Wait for the threads to finish
+  for (auto &t : threads) {
+    t.join();
+  }
+}
+
+}  // namespace
+
+namespace onnxruntime {
+
+constexpr uint64_t count = 1000000ull;
+
+TEST(BarrierTest, TestBarrier_0Workers_Spin) {
+  TestBarrier(0, count, true);
+}
+
+TEST(BarrierTest, TestBarrier_0Workers_Block) {
+  TestBarrier(0, count, false);
+}
+
+TEST(BarrierTest, TestBarrier_1Worker_Spin) {
+  TestBarrier(1, count, true);
+}
+
+TEST(BarrierTest, TestBarrier_1Worker_Block) {
+  TestBarrier(1, count, false);
+}
+
+TEST(BarrierTest, TestBarrier_4Workers_Spin) {
+  TestBarrier(4, count, true);
+}
+
+TEST(BarrierTest, TestBarrier_4Workers_Block) {
+  TestBarrier(4, count, false);
+}
+
+}  // namespace onnxruntime