onnxruntime/winml/test/concurrency/ThreadPool.h
Tiago Koji Castro Shibata c3cea486d0
Port ConcurrencyTests from TAEF (#3086)
* Add ConcurrencyTests

* Make ConcurrencyTests compatible with TAEF

* Use test PCH in concurrency tests

* Fix include header

* Ignore unused code warnings on WINML_SKIP_TEST

* Remove BOM

* Remove conflicting namespace in older SDK

* Refactor duplicate code

* Fix unused DELAYLOAD

* Fix unused DELAYLOAD

* Remove link to internal bug

* Address code style fixes

* Add new concurrency tests
2020-03-27 17:39:22 -07:00

33 lines
1 KiB
C++

#pragma once
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
/// Pool of worker threads that execute callables submitted via SubmitWork().
///
/// Only the submission path is defined in this header. The constructor and
/// destructor are defined out of line; presumably they start
/// `initial_pool_size` workers that drain `m_work_queue` and stop them once
/// `m_destruct_pool` is set — TODO confirm against the implementation file.
///
/// The pool owns a mutex and a condition variable (and threads), so it is
/// neither copyable nor movable.
class ThreadPool {
private:
  // Notified once per submission so an idle worker can pick the new task up.
  std::condition_variable m_cond_var;
  // Shutdown flag. Default-initialized so it never holds an indeterminate value,
  // even if the out-of-line constructor does not list it in its initializer list.
  bool m_destruct_pool{false};
  // Guards m_work_queue (and presumably m_destruct_pool; verify in the .cpp).
  std::mutex m_mutex;
  std::vector<std::thread> m_threads;
  // Pending tasks, type-erased to a uniform void() signature.
  std::queue<std::function<void()>> m_work_queue;

public:
  ThreadPool(unsigned int initial_pool_size);
  ~ThreadPool();

  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;

  /// Queues `f(args...)` for asynchronous execution on a pool thread.
  ///
  /// Submissions from multiple threads are serialized on m_mutex.
  ///
  /// @param f     Callable to run; stored by decay-copy (moved if an rvalue).
  /// @param args  Arguments bound to `f`, also stored by decay-copy; wrap them
  ///              in std::ref / std::cref to bind by reference instead.
  /// @return A future that receives the callable's result, or the exception it
  ///         threw, once a worker has executed the task.
  template <typename F, typename... Args>
  auto SubmitWork(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
    using result_type = decltype(f(args...));

    auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    // std::packaged_task is move-only, but std::function needs a copyable target,
    // so the queued wrapper lambda shares ownership of the task instead.
    auto task = std::make_shared<std::packaged_task<result_type()>>(std::move(bound));

    // Take the future BEFORE the task is visible to workers. Once queued, a
    // worker may run (*task)() at any time, and packaged_task gives no guarantee
    // that get_future() does not race with operator().
    std::future<result_type> result = task->get_future();
    {
      std::lock_guard<std::mutex> lock(m_mutex);
      m_work_queue.emplace([task]() { (*task)(); });
    }
    // Notify after releasing the lock so the woken worker does not immediately
    // block on m_mutex.
    m_cond_var.notify_one();
    return result;
  }
};