[RPC][Better Engineering] Consolidated all rpcAgentRunning atomic booleans (#33915)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33915

Closes: https://github.com/pytorch/pytorch/issues/32963

Test Plan: build bot

Reviewed By: jjlilley

Differential Revision: D20074714

fbshipit-source-id: ee89e76f547a1da71825a317c096176524504290
This commit is contained in:
Omkar Salpekar 2020-04-03 11:47:59 -07:00 committed by Facebook GitHub Bot
parent c5c63a2e35
commit 19bbfbe1cf
4 changed files with 35 additions and 34 deletions

View file

@ -142,7 +142,7 @@ ProcessGroupAgent::ProcessGroupAgent(
}
ProcessGroupAgent::~ProcessGroupAgent() {
if (rpcRunning_) {
if (rpcAgentRunning_) {
shutdown();
}
}
@ -240,11 +240,7 @@ void ProcessGroupAgent::sync() {
} while (hasPendingMessage());
}
void ProcessGroupAgent::start() {
{
std::lock_guard<std::mutex> futureLock{futureMutex_};
rpcRunning_.store(true);
}
void ProcessGroupAgent::startImpl() {
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
futureTimeoutThread_ =
std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
@ -254,7 +250,7 @@ void ProcessGroupAgent::shutdown() {
LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
<< ".";
std::unique_lock<std::mutex> lock{futureMutex_};
if (!rpcRunning_.exchange(false)) {
if (!rpcAgentRunning_.exchange(false)) {
return;
}
lock.unlock();
@ -297,7 +293,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
}
}
if (!rpcRunning_.load()) {
if (!rpcAgentRunning_.load()) {
// We are trying to send but RPC has been shut down on this node. This can
// happen if we are in a shutdown sequence but background threads are still
// processing messages that result in send()s. Throw a descriptive error.
@ -449,7 +445,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) {
}
for (auto& pendingSend : pendingSends) {
if (!rpcRunning_.load() || !pendingSend->wait()) {
if (!rpcAgentRunning_.load() || !pendingSend->wait()) {
// Send was interrupted or RPC is not running.
return;
}
@ -694,7 +690,7 @@ void ProcessGroupAgent::listenLoop() {
}
void ProcessGroupAgent::listenLoopInternal() {
while (rpcRunning_.load()) {
while (rpcAgentRunning_.load()) {
// rank, tensor size, message type
std::vector<torch::Tensor> preamble = {torch::empty({4}, {torch::kInt64})};
auto work = pg_->recvAnysource(preamble, pg_->getRank());
@ -703,7 +699,7 @@ void ProcessGroupAgent::listenLoopInternal() {
recvWork_ = work;
}
if (!rpcRunning_.load() || !work->wait() /* aborted */) {
if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
return;
}
@ -723,7 +719,7 @@ void ProcessGroupAgent::listenLoopInternal() {
}
void ProcessGroupAgent::pollTimedOutRPCs() {
while (rpcRunning_.load()) {
while (rpcAgentRunning_.load()) {
std::unique_lock<std::mutex> lock{futureMutex_};
steady_clock_time_point minEndTime;
// Estimate amount of time the first future will time out in, and sleep
@ -737,13 +733,13 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
}
auto shouldUpdateMinEndTimePredicate = [&, this]() -> bool {
// Notice, whoever modifying `rpcRunning_`
// Notice, whoever modifying `rpcAgentRunning_`
// must acquire lock on `futureMutex_`.
// Otherwise, this predicate could deadlock.
// If during evaluating the predicate, `::shutdown()` is called, then
// the predicate missed the notification before it started waiting
// on the cond var.
if (!rpcRunning_.load()) {
if (!rpcAgentRunning_.load()) {
return true;
}
steady_clock_time_point minEndTimeInMap = kInfiniteTimeoutTimePoint;

View file

@ -74,7 +74,7 @@ class ProcessGroupAgent : public RpcAgent {
void sync() override;
void start() override;
void startImpl() override;
void shutdown() override;
@ -219,13 +219,6 @@ class ProcessGroupAgent : public RpcAgent {
MessageCounter recvCounts_;
std::atomic<int64_t> nextId_;
// atomic bool indicating if this agent is running. It is set in
// ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
// ProcessGroupAgent::join. It controls whether several background threads
// should be running.
// We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
// conditions when notifying condition variables.
std::atomic<bool> rpcRunning_{false};
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
// when using the same tag.
std::vector<std::mutex> sendMutexes_;

View file

@ -20,20 +20,24 @@ RpcAgent::RpcAgent(
cb_(std::move(cb)),
rpcTimeout_(rpcTimeout),
profilingEnabled_(false),
rpcAgentRunning_(true) {
rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
}
rpcAgentRunning_(false) {}
RpcAgent::~RpcAgent() {
cleanup();
}
void RpcAgent::start() {
rpcAgentRunning_.store(true);
rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
startImpl();
}
void RpcAgent::cleanup() {
if (!rpcAgentRunning_.exchange(false)) {
return;
rpcAgentRunning_.store(false);
if (rpcRetryThread_.joinable()) {
rpcRetryMapCV_.notify_one();
rpcRetryThread_.join();
}
rpcRetryMapCV_.notify_one();
rpcRetryThread_.join();
}
std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(

View file

@ -198,8 +198,15 @@ class TORCH_API RpcAgent {
// all ``RpcAgent``s reach this method and send all pending messages.
virtual void sync() = 0;
// start accepting requests
virtual void start() = 0;
// Sets up backend-agnostic state for accepting requests. Currently, this
// entails setting rpcAgentRunning_ to true, creating the retry thread, and
// calling the backend's startImpl.
void start();
// Derived classes must override this function to start accepting requests.
// This is used to initialize any backend-specific state. Users must call
// start, not startImpl, to initialize the RPC Agent.
virtual void startImpl() = 0;
// Stop accepting requests and shutdown the RPC framework as soon as possible
// by terminating all RPC threads.
@ -240,6 +247,10 @@ class TORCH_API RpcAgent {
std::atomic<std::chrono::milliseconds> rpcTimeout_;
std::atomic<bool> profilingEnabled_;
std::shared_ptr<TypeResolver> typeResolver_;
// Atomic boolean indicating whether this agent is running. It controls
// whether several background threads should be running. It is set in
// RpcAgent::start() and unset in the derived class shutdown().
std::atomic<bool> rpcAgentRunning_;
private:
static std::shared_ptr<RpcAgent> currentRpcAgent_;
@ -295,9 +306,6 @@ class TORCH_API RpcAgent {
std::chrono::steady_clock::now() + timedelta);
}
// Boolean that indicates whether RpcAgent is running.
std::atomic<bool> rpcAgentRunning_;
// storing futures before adding callback
std::vector<
std::pair<std::shared_ptr<FutureMessage>, std::shared_ptr<RpcRetryInfo>>>