[RPC][Better Engineering] Consolidated all rpcAgentRunning atomic booleans (#33915)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33915
Closes: https://github.com/pytorch/pytorch/issues/32963

Test Plan: build bot

Reviewed By: jjlilley

Differential Revision: D20074714

fbshipit-source-id: ee89e76f547a1da71825a317c096176524504290
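The commit replaces the per-backend running flag (`ProcessGroupAgent::rpcRunning_`) with a single `rpcAgentRunning_` owned by the `RpcAgent` base class, and splits startup into a non-virtual `start()` plus a virtual `startImpl()`. A minimal sketch of the resulting shape, with simplified names (`Agent`, `running_`, `shutdownImpl` are illustrative, not the actual PyTorch API):

#include <atomic>

class Agent {
 public:
  virtual ~Agent() = default;

  // Non-virtual entry point: flips the shared flag, then hands
  // off to the backend-specific startup hook.
  void start() {
    running_.store(true);
    startImpl();
  }

  // exchange(false) returns the previous value, so the teardown
  // body runs at most once even if shutdown() is called twice.
  void shutdown() {
    if (!running_.exchange(false)) {
      return;
    }
    shutdownImpl();
  }

 protected:
  virtual void startImpl() = 0;     // backend-specific startup
  virtual void shutdownImpl() = 0;  // backend-specific teardown
  std::atomic<bool> running_{false};
};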
commit 19bbfbe1cf (parent c5c63a2e35)

4 changed files with 35 additions and 34 deletions
torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -142,7 +142,7 @@ ProcessGroupAgent::ProcessGroupAgent(
 }
 
 ProcessGroupAgent::~ProcessGroupAgent() {
-  if (rpcRunning_) {
+  if (rpcAgentRunning_) {
     shutdown();
   }
 }
@@ -240,11 +240,7 @@ void ProcessGroupAgent::sync() {
   } while (hasPendingMessage());
 }
 
-void ProcessGroupAgent::start() {
-  {
-    std::lock_guard<std::mutex> futureLock{futureMutex_};
-    rpcRunning_.store(true);
-  }
+void ProcessGroupAgent::startImpl() {
   listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
   futureTimeoutThread_ =
       std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
@@ -254,7 +250,7 @@ void ProcessGroupAgent::shutdown() {
   LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
             << ".";
   std::unique_lock<std::mutex> lock{futureMutex_};
-  if (!rpcRunning_.exchange(false)) {
+  if (!rpcAgentRunning_.exchange(false)) {
     return;
   }
   lock.unlock();
@@ -297,7 +293,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
     }
   }
 
-  if (!rpcRunning_.load()) {
+  if (!rpcAgentRunning_.load()) {
     // We are trying to send but RPC has been shut down on this node. This can
     // happen if we are in a shutdown sequence but background threads are still
     // processing messages that result in send()s. Throw a descriptive error.
@@ -449,7 +445,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) {
   }
 
   for (auto& pendingSend : pendingSends) {
-    if (!rpcRunning_.load() || !pendingSend->wait()) {
+    if (!rpcAgentRunning_.load() || !pendingSend->wait()) {
       // Send was interrupted or RPC is not running.
       return;
     }
@@ -694,7 +690,7 @@ void ProcessGroupAgent::listenLoop() {
 }
 
 void ProcessGroupAgent::listenLoopInternal() {
-  while (rpcRunning_.load()) {
+  while (rpcAgentRunning_.load()) {
     // rank, tensor size, message type
     std::vector<torch::Tensor> preamble = {torch::empty({4}, {torch::kInt64})};
     auto work = pg_->recvAnysource(preamble, pg_->getRank());
@@ -703,7 +699,7 @@ void ProcessGroupAgent::listenLoopInternal() {
     recvWork_ = work;
   }
 
-  if (!rpcRunning_.load() || !work->wait() /* aborted */) {
+  if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
     return;
   }
 
@@ -723,7 +719,7 @@ void ProcessGroupAgent::listenLoopInternal() {
 }
 
 void ProcessGroupAgent::pollTimedOutRPCs() {
-  while (rpcRunning_.load()) {
+  while (rpcAgentRunning_.load()) {
     std::unique_lock<std::mutex> lock{futureMutex_};
     steady_clock_time_point minEndTime;
     // Estimate amount of time the first future will time out in, and sleep
@@ -737,13 +733,13 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
     }
 
     auto shouldUpdateMinEndTimePredicate = [&, this]() -> bool {
-      // Notice, whoever modifying `rpcRunning_`
+      // Notice, whoever modifying `rpcAgentRunning_`
       // must acquire lock on `futureMutex_`.
       // Otherwise, this predicate could deadlock.
      // If during evaluating the predicate, `::shutdown()` is called, then
      // the predicate missed the notification before it started waiting
      // on the cond var.
-      if (!rpcRunning_.load()) {
+      if (!rpcAgentRunning_.load()) {
        return true;
      }
      steady_clock_time_point minEndTimeInMap = kInfiniteTimeoutTimePoint;
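The comment in this hunk describes a classic lost-wakeup hazard: if the flag is cleared and the condition variable notified between the waiter evaluating its predicate and actually blocking, the notification is missed forever. Clearing the flag while holding the same mutex the waiter uses closes that window. A standalone sketch of the pattern (generic names `mu`, `cv`, `running`, not the PyTorch code):

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex mu;
std::condition_variable cv;
std::atomic<bool> running{true};

void waiter() {
  std::unique_lock<std::mutex> lock{mu};
  // wait() re-checks the predicate under `mu`, so a writer that
  // also holds `mu` cannot slip a notify in between the check
  // and the block.
  cv.wait(lock, [] { return !running.load(); });
}

void stopper() {
  {
    // Clearing the flag under the same mutex guarantees the waiter
    // is either before its predicate check (and will see false) or
    // already blocked (and will receive the notify).
    std::lock_guard<std::mutex> lock{mu};
    running.store(false);
  }
  cv.notify_all();
}

int main() {
  std::thread t(waiter);
  stopper();
  t.join();  // always terminates: no lost wakeup is possible
}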
torch/csrc/distributed/rpc/process_group_agent.h
@@ -74,7 +74,7 @@ class ProcessGroupAgent : public RpcAgent {
 
   void sync() override;
 
-  void start() override;
+  void startImpl() override;
 
   void shutdown() override;
 
@@ -219,13 +219,6 @@ class ProcessGroupAgent : public RpcAgent {
   MessageCounter recvCounts_;
 
   std::atomic<int64_t> nextId_;
-  // atomic bool indicating if this agent is running. It is set in
-  // ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
-  // ProcessGroupAgent::join. It controls whether several background threads
-  // should be running.
-  // We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
-  // conditions when notifying condition variables.
-  std::atomic<bool> rpcRunning_{false};
   // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
   // when using the same tag.
   std::vector<std::mutex> sendMutexes_;
torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -20,20 +20,24 @@ RpcAgent::RpcAgent(
       cb_(std::move(cb)),
       rpcTimeout_(rpcTimeout),
       profilingEnabled_(false),
-      rpcAgentRunning_(true) {
-  rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
-}
+      rpcAgentRunning_(false) {}
 
 RpcAgent::~RpcAgent() {
   cleanup();
 }
 
+void RpcAgent::start() {
+  rpcAgentRunning_.store(true);
+  rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
+  startImpl();
+}
+
 void RpcAgent::cleanup() {
-  if (!rpcAgentRunning_.exchange(false)) {
-    return;
+  rpcAgentRunning_.store(false);
+  if (rpcRetryThread_.joinable()) {
+    rpcRetryMapCV_.notify_one();
+    rpcRetryThread_.join();
   }
-  rpcRetryMapCV_.notify_one();
-  rpcRetryThread_.join();
 }
 
 std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(
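Because the retry thread is now spawned in start() rather than in the constructor, cleanup() can run on an agent that was never started; the new joinable() check keeps the join from throwing in that case. A reduced illustration of that guard (hypothetical `Worker`, not part of this diff):

#include <thread>

struct Worker {
  std::thread t;  // default-constructed: not joinable until started

  void start() {
    t = std::thread([] { /* background work */ });
  }

  ~Worker() {
    // join() on a thread that was never started throws
    // std::system_error, so teardown must guard the join --
    // the same reason cleanup() checks joinable() above.
    if (t.joinable()) {
      t.join();
    }
  }
};

int main() {
  Worker never_started;  // destructor is safe thanks to the guard
  Worker started;
  started.start();
}  // both destructors run without error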
torch/csrc/distributed/rpc/rpc_agent.h
@@ -198,8 +198,15 @@ class TORCH_API RpcAgent {
   // all ``RpcAgent``s reach this method and send all pending messages.
   virtual void sync() = 0;
 
-  // start accepting requests
-  virtual void start() = 0;
+  // Sets up backend-agnostic state for accepting requests. Currently, this
+  // entails setting rpcAgentRunning_ to true, creating the retry thread, and
+  // calling the backend's startImpl.
+  void start();
+
+  // Derived classes must override this function to start accepting requests.
+  // This is used to initialize any backend-specific state. Users must call
+  // start, not startImpl, to initialize the RPC Agent.
+  virtual void startImpl() = 0;
 
   // Stop accepting requests and shutdown the RPC framework as soon as possible
   // by terminating all RPC threads.
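To make the documented contract concrete, here is a toy backend under the same start()/startImpl() split, reusing the simplified Agent base class sketched after the summary (ToyAgent is hypothetical, not part of this diff):

// Derived classes supply only the backend-specific hooks.
class ToyAgent : public Agent {
 protected:
  void startImpl() override {
    // backend-specific listeners/threads would be created here
  }
  void shutdownImpl() override {
    // backend-specific teardown
  }
};

// Callers go through the non-virtual entry points only:
//   ToyAgent agent;
//   agent.start();     // sets the running flag, then calls startImpl()
//   agent.shutdown();  // idempotent via exchange(false)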
@@ -240,6 +247,10 @@ class TORCH_API RpcAgent {
   std::atomic<std::chrono::milliseconds> rpcTimeout_;
   std::atomic<bool> profilingEnabled_;
   std::shared_ptr<TypeResolver> typeResolver_;
+  // Atomic boolean indicating whether this agent is running. It controls
+  // whether several background threads should be running. It is set in
+  // RpcAgent::start() and unset in the derived class shutdown().
+  std::atomic<bool> rpcAgentRunning_;
 
  private:
   static std::shared_ptr<RpcAgent> currentRpcAgent_;
@@ -295,9 +306,6 @@ class TORCH_API RpcAgent {
         std::chrono::steady_clock::now() + timedelta);
   }
 
-  // Boolean that indicates whether RpcAgent is running.
-  std::atomic<bool> rpcAgentRunning_;
-
   // storing futures before adding callback
   std::vector<
       std::pair<std::shared_ptr<FutureMessage>, std::shared_ptr<RpcRetryInfo>>>