diff --git a/include/onnxruntime/core/platform/Barrier.h b/include/onnxruntime/core/platform/Barrier.h index 67cc29e3a3..f851cae40e 100644 --- a/include/onnxruntime/core/platform/Barrier.h +++ b/include/onnxruntime/core/platform/Barrier.h @@ -7,14 +7,18 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include + #include "core/platform/ort_mutex.h" + #include #include namespace onnxruntime { class Barrier { public: - explicit Barrier(unsigned int count) : state_(count << 1), notified_(false) { + explicit Barrier(unsigned int count, bool spin = false) + : state_(count << 1), notified_(false), spin_(spin) { assert(((count << 1) >> 1) == count); } #ifdef NDEBUG @@ -42,12 +46,18 @@ class Barrier { } void Wait() { - unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); - if ((v >> 1) == 0) - return; - std::unique_lock l(mu_); - while (!notified_) { - cv_.wait(l); + if (spin_) { + while ((state_ >> 1) != 0) { + /* spin */ + } + } else { + unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); + if ((v >> 1) == 0) + return; + std::unique_lock l(mu_); + while (!notified_) { + cv_.wait(l); + } } } @@ -56,6 +66,7 @@ class Barrier { OrtCondVar cv_; std::atomic state_; // low bit is waiter flag bool notified_; + const bool spin_; }; // Notification is an object that allows a user to to wait for another diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 04a47ca5d2..a86bdaf57d 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -557,7 +557,7 @@ void RunInParallel(std::function fn, unsigned n) override { // item. This lets us remove any work items that do not get executed by the threads // that we push them to. std::vector> pending_items; - Barrier b(n); + Barrier b(n, allow_spinning_); my_pt->in_parallel = true; if (!my_pt->tag.Get()) { diff --git a/onnxruntime/test/platform/barrier_test.cc b/onnxruntime/test/platform/barrier_test.cc new file mode 100644 index 0000000000..860714e569 --- /dev/null +++ b/onnxruntime/test/platform/barrier_test.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/Barrier.h" +#include "core/platform/threadpool.h" + +#include + +#include "gtest/gtest.h" + +#include +#include + +namespace { + +static void TestBarrier(int num_threads, uint64_t per_thread_count, bool spin) { + std::atomic counter{0}; + onnxruntime::Barrier barrier(num_threads, spin); + + std::vector threads; + for (auto i = 0; i < num_threads + 1; i++) { + threads.push_back(std::thread([&, i] { + if (i > 0) { + // Worker thread; increment the shared counter then + // notify the barrier. + for (uint64_t j = 0; j < per_thread_count; j++) { + counter++; + } + barrier.Notify(); + } else { + // Main thread; wait on the barrier, and then check the count seen. + barrier.Wait(); + ASSERT_EQ(counter, per_thread_count * num_threads); + } + })); + } + + // Wait for the threads to finish + for (auto &t : threads) { + t.join(); + } +} + +} // namespace + +namespace onnxruntime { + +constexpr uint64_t count = 1000000ull; + +TEST(BarrierTest, TestBarrier_0Workers_Spin) { + TestBarrier(0, count, true); +} + +TEST(BarrierTest, TestBarrier_0Workers_Block) { + TestBarrier(0, count, false); +} + +TEST(BarrierTest, TestBarrier_1Worker_Spin) { + TestBarrier(1, count, true); +} + +TEST(BarrierTest, TestBarrier_1Worker_Block) { + TestBarrier(1, count, false); +} + +TEST(BarrierTest, TestBarrier_4Workers_Spin) { + TestBarrier(4, count, true); +} + +TEST(BarrierTest, TestBarrier_4Workers_Block) { + TestBarrier(4, count, false); +} + +} // namespace onnxruntime