diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index a86039c6ef4..00bd235c866 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -16,7 +16,7 @@ namespace c10d { ncclComm_t NCCLComm::getNcclComm() { - std::unique_lock lock(mutex_); + LockType lock(mutex_); if (aborted_) { auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index af32ab83ef5..0089d453bb8 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -275,6 +275,9 @@ class TORCH_API DebugInfoWriter { // RAII wrapper for NCCL communicator class NCCLComm { + using MutexType = std::recursive_mutex; + using LockType = std::unique_lock; + public: explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {} @@ -283,7 +286,7 @@ class NCCLComm { ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. - std::unique_lock lock(mutex_); + LockType lock(mutex_); if (ncclComm_ && initialized_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since @@ -371,7 +374,7 @@ class NCCLComm { NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. - std::unique_lock lock(other.mutex_); + LockType lock(other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); @@ -382,13 +385,13 @@ class NCCLComm { ncclComm_t getNcclComm(); std::optional getNcclCommFailureReason() const { - std::unique_lock lock(mutex_); + LockType lock(mutex_); return commFailureReason_; } void ncclCommAbort( std::optional commFailureReason = std::nullopt) { - std::unique_lock lock(mutex_); + LockType lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_ && !initialized_) { // Should not abort twice. @@ -436,12 +439,12 @@ class NCCLComm { } bool isInitialized() const { - std::unique_lock lock(mutex_); + LockType lock(mutex_); return initialized_; } bool isAborted() const { - std::unique_lock lock(mutex_); + LockType lock(mutex_); return aborted_; } @@ -450,7 +453,7 @@ class NCCLComm { } ncclResult_t checkForNcclError() { - std::unique_lock lock(mutex_); + LockType lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; @@ -465,7 +468,7 @@ class NCCLComm { } ncclResult_t registerSegment(void* ptr, size_t size) { - std::unique_lock lock(mutex_); + LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always @@ -498,7 +501,7 @@ class NCCLComm { } ncclResult_t deregisterSegment(void* ptr) { - std::unique_lock lock(mutex_); + LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( registeredSegmentHandles_.count(ptr) == 1, @@ -538,7 +541,7 @@ class NCCLComm { bool aborted_{false}; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_{ncclSuccess}; - mutable std::mutex mutex_; + mutable MutexType mutex_; // Rank that this communicator corresponds to. int rank_{}; // Optional reason for communicator failure, provided by ProcessGroupNCCL for