diff --git a/caffe2/contrib/gloo/store_handler.cc b/caffe2/contrib/gloo/store_handler.cc index c9c116c266d..e5c30d9c71c 100644 --- a/caffe2/contrib/gloo/store_handler.cc +++ b/caffe2/contrib/gloo/store_handler.cc @@ -15,8 +15,10 @@ std::vector<char> StoreHandlerWrapper::get(const std::string& key) { return std::vector<char>(str.begin(), str.end()); } -void StoreHandlerWrapper::wait(const std::vector<std::string>& keys) { - handler_.wait(keys); +void StoreHandlerWrapper::wait( + const std::vector<std::string>& keys, + const std::chrono::milliseconds& timeout) { + handler_.wait(keys, timeout); } } // namespace gloo diff --git a/caffe2/contrib/gloo/store_handler.h b/caffe2/contrib/gloo/store_handler.h index 6621a87d936..59e131bdd5f 100644 --- a/caffe2/contrib/gloo/store_handler.h +++ b/caffe2/contrib/gloo/store_handler.h @@ -18,7 +18,13 @@ class StoreHandlerWrapper : public ::gloo::rendezvous::Store { virtual std::vector<char> get(const std::string& key) override; - virtual void wait(const std::vector<std::string>& keys) override; + virtual void wait(const std::vector<std::string>& keys) override { + wait(keys, ::gloo::rendezvous::Store::kDefaultTimeout); + } + + virtual void wait( + const std::vector<std::string>& keys, + const std::chrono::milliseconds& timeout) override; protected: StoreHandler& handler_; diff --git a/caffe2/distributed/file_store_handler.cc b/caffe2/distributed/file_store_handler.cc index 480b57df4a6..69267ae9077 100644 --- a/caffe2/distributed/file_store_handler.cc +++ b/caffe2/distributed/file_store_handler.cc @@ -119,10 +119,18 @@ bool FileStoreHandler::check(const std::vector<std::string>& names) { return true; } -void FileStoreHandler::wait(const std::vector<std::string>& names) { +void FileStoreHandler::wait( + const std::vector<std::string>& names, + const std::chrono::milliseconds& timeout) { // Not using inotify because it doesn't work on many // shared filesystems (such as NFS). 
+ const auto start = std::chrono::steady_clock::now(); while (!check(names)) { + const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now() - start); + if (timeout != kNoTimeout && elapsed > timeout) { + CAFFE_ENFORCE(false, "Wait timeout for name(s): ", Join(" ", names)); + } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); } diff --git a/caffe2/distributed/file_store_handler.h b/caffe2/distributed/file_store_handler.h index 2447508e5f0..91a169f6fc8 100644 --- a/caffe2/distributed/file_store_handler.h +++ b/caffe2/distributed/file_store_handler.h @@ -17,7 +17,9 @@ class FileStoreHandler : public StoreHandler { virtual bool check(const std::vector<std::string>& names) override; - virtual void wait(const std::vector<std::string>& names) override; + virtual void wait( + const std::vector<std::string>& names, + const std::chrono::milliseconds& timeout = kDefaultTimeout) override; protected: std::string basePath_; diff --git a/caffe2/distributed/redis_store_handler.cc b/caffe2/distributed/redis_store_handler.cc index 244758fbd08..58f313b9d9a 100644 --- a/caffe2/distributed/redis_store_handler.cc +++ b/caffe2/distributed/redis_store_handler.cc @@ -89,12 +89,20 @@ bool RedisStoreHandler::check(const std::vector<std::string>& names) { return reply->integer == names.size(); } -void RedisStoreHandler::wait(const std::vector<std::string>& names) { +void RedisStoreHandler::wait( + const std::vector<std::string>& names, + const std::chrono::milliseconds& timeout) { // Simple approach: poll... // Complex approach: use pub/sub. // Polling is fine for the typical rendezvous use case, as it is // only done at initialization time and not at run time. 
+ const auto start = std::chrono::steady_clock::now(); while (!check(names)) { + const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now() - start); + if (timeout != kNoTimeout && elapsed > timeout) { + CAFFE_ENFORCE(false, "Wait timeout for name(s): ", Join(" ", names)); + } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); } diff --git a/caffe2/distributed/redis_store_handler.h b/caffe2/distributed/redis_store_handler.h index 8dba5ede404..ee208d9964e 100644 --- a/caffe2/distributed/redis_store_handler.h +++ b/caffe2/distributed/redis_store_handler.h @@ -23,7 +23,9 @@ class RedisStoreHandler : public StoreHandler { virtual bool check(const std::vector<std::string>& names) override; - virtual void wait(const std::vector<std::string>& names) override; + virtual void wait( + const std::vector<std::string>& names, + const std::chrono::milliseconds& timeout = kDefaultTimeout) override; private: std::string host_; diff --git a/caffe2/distributed/store_handler.cc b/caffe2/distributed/store_handler.cc index 0bb9bcf7dc0..5c585f8544b 100644 --- a/caffe2/distributed/store_handler.cc +++ b/caffe2/distributed/store_handler.cc @@ -6,6 +6,9 @@ namespace caffe2 { +constexpr std::chrono::milliseconds StoreHandler::kDefaultTimeout; +constexpr std::chrono::milliseconds StoreHandler::kNoTimeout; + StoreHandler::~StoreHandler() { // NOP; definition is here to make sure library contains // symbols for this abstract class. 
diff --git a/caffe2/distributed/store_handler.h b/caffe2/distributed/store_handler.h index 329a0e6fed5..80380257418 100644 --- a/caffe2/distributed/store_handler.h +++ b/caffe2/distributed/store_handler.h @@ -1,5 +1,6 @@ #pragma once +#include <chrono> #include #include #include @@ -8,6 +9,11 @@ namespace caffe2 { class StoreHandler { public: + static constexpr std::chrono::milliseconds kDefaultTimeout = + std::chrono::seconds(30); + static constexpr std::chrono::milliseconds kNoTimeout = + std::chrono::milliseconds::zero(); + virtual ~StoreHandler(); virtual void set(const std::string& name, const std::string& data) = 0; @@ -18,6 +24,8 @@ class StoreHandler { virtual bool check(const std::vector<std::string>& names) = 0; - virtual void wait(const std::vector<std::string>& names) = 0; + virtual void wait( + const std::vector<std::string>& names, + const std::chrono::milliseconds& timeout = kDefaultTimeout) = 0; }; }