diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index f1cbf47ea0f..edd51c89875 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -577,6 +577,7 @@ class ProcessGroupNCCL(Backend):
     def perform_nocolor_split(self, device: torch.device) -> None: ...
     def comm_split_count(self) -> int: ...
     def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
+    def abort(self) -> None: ...
     @property
     def uid(self) -> int: ...
     @property
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 48d07983f2b..eaeeb22199e 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1251,9 +1251,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
 }

 // Abort all communicators on this rank
-bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
-  // This will log counter for how long the abort actually takes.
-  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
+// Note: original name of this method is `abort`. It was renamed to
+// `abortComms` to distinguish from the `abort` method below. The `abort`
+// method calls `abortComms` but does more destruction than the latter.
+bool ProcessGroupNCCL::abortComms(std::optional<std::string> abortReason) {
   // Remove record from global ncclCommDevIdxMapMutex before aboarting,
   // so that a new cache segment would not register to already aborded
   // communicators. Note that ncclCommDevIdxMap is a global container which may
@@ -1272,7 +1273,11 @@ bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
   return true;
 }

-void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
+// Abort this backend.
+void ProcessGroupNCCL::abort() {
+  // This will log counter for how long the abort actually takes.
+  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
+
   // Don't join threads here since the purpose of this method is to abort all
   // communicators and signal the threads to exit. Joining on the threads could
   // potentially block and hence avoid it in this method.
@@ -1282,8 +1287,8 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
   // lauch abort asynchrounously and wait for it to complete or timeout
   LOG(INFO) << logPrefix()
             << "Launching ProcessGroupNCCL abort asynchrounously.";
-  std::future<bool> fut = std::async(
-      std::launch::async, [this, &reason]() { return this->abort(reason); });
+  std::future<bool> fut =
+      std::async(std::launch::async, [this]() { return this->abortComms(); });

   waitForFutureOrTimeout(
       fut, options_->timeout, "ProcessGroup abort", true, false);
@@ -1295,6 +1300,15 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
   monitorWakeUpCV_.notify_one();
 }

+// Destroy (shutdown) this backend -- normal exit.
+void ProcessGroupNCCL::shutdown() {
+  // kwen2501 (Aug 2024): moved code of `shutdown()` to `abort()` because it
+  // actually implemented an abort behavior.
+  // TODO: implementation of `shutdown` should use ncclCommDestroy() instead
+  // of ncclCommAbort(). Ideally non-blocking API mode should be used.
+  this->abort();
+}
+
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";

@@ -1881,7 +1895,7 @@ void ProcessGroupNCCL::watchdogHandler() {
         work.abort();
         // PG level abort, which would abort all other communicators on this
         // rank
-        abort();
+        abortComms();
       }

       // Report desync state in case of timeout
@@ -2039,7 +2053,7 @@ void ProcessGroupNCCL::runHookLoop() {
       // already finished successfully at this point. We just need to abort
      // the process Abort all NCCL Communicators on this ProcessGroupNCCL
       // instance.
-      abort(errorStr);
+      abortComms(errorStr);
     }
   }

diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 428357fc908..9d89f429c1a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -686,12 +686,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();

+  // Destroy (shutdown) this backend -- normal exit.
+  void shutdown();
+
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
-  // return true if abort is successful, otherwise false
-  bool abort(const std::optional<std::string>& abortReason = std::nullopt);
-
-  void shutdown(const std::optional<std::string>& reason = std::nullopt);
+  void abort();

   void eagerConnectSingleDevice(at::Device device) override;

@@ -753,6 +753,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // operations, we might need to use a side thread to do it.
   bool dumpDebuggingInfo();

+  // Abort all communicators on this rank.
+  bool abortComms(std::optional<std::string> abortReason = std::nullopt);
+
 private:
   int globalRankStart;
   int globalRankStride;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index e9ee1547ef6..034b2ccde41 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2759,7 +2759,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           &::c10d::ProcessGroupNCCL::setBoundDeviceId)
       .def(
           "perform_nocolor_split",
-          &::c10d::ProcessGroupNCCL::performNocolorSplit);
+          &::c10d::ProcessGroupNCCL::performNocolorSplit)
+      .def(
+          "abort",
+          &::c10d::ProcessGroupNCCL::abort,
+          py::call_guard<py::gil_scoped_release>());

   module.def(
       "_get_intra_node_comm_usage_counter",
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ce2da9e891f..f169d180c25 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1706,6 +1706,20 @@ def _shutdown_backend(pg):
     backend._shutdown()


+def _abort_backend(pg: ProcessGroup):
+    """
+    Abort the backend of a process group.
+    Currently, only ProcessGroupNCCL backend is supported.
+    No op for other backends.
+    """
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        backend = None
+    if isinstance(backend, ProcessGroupNCCL):
+        backend.abort()
+
+
 def _new_process_group_helper(
     group_size,
     group_rank,
@@ -2064,6 +2078,101 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
     _unregister_process_group(pg.group_name)


+def _abort_process_group(group: Optional[ProcessGroup] = None):
+    """
+    Abort a given process group. If group.WORLD (i.e. `None`) is given, all
+    process groups including the default one will be aborted.
+
+    Args:
+        group (ProcessGroup, optional): The process group to be aborted.
+
+    .. note:: this API is experimental and currently only works with the NCCL
+        backend.
+
+    .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING`
+        turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may
+        automatically handle errors or timeouts for you including aborting the
+        ProcessGroup.
+    """
+    global _world
+
+    if group == GroupMember.NON_GROUP_MEMBER:
+        return
+
+    pg = group or GroupMember.WORLD
+
+    assert pg is not None
+    if _world.pg_map.get(pg, None) is None:
+        raise ValueError("Invalid process group specified or has been destroyed.")
+
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        backend = None
+
+    if not isinstance(backend, ProcessGroupNCCL):
+        logger.warning(
+            "`abort_process_group` currently only has implementation for ProcessGroupNCCL; "
+            "however, no NCCL backend is found. This call will be a no-op."
+        )
+        return
+
+    if group == GroupMember.WORLD:
+        # Abort all backends within a ncclGroupStart|End semantic.
+        # This ensures that different NCCL communicators' abort calls won't
+        # deadlock each other.
+        # For details, please see: https://github.com/pytorch/pytorch/issues/119797
+        backend._group_start()
+        for pg_to_abort in sorted(
+            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
+        ):
+            _abort_backend(pg_to_abort)
+        backend._group_end()
+
+        _update_default_pg(None)
+        _world.pg_map.clear()
+        _world.pg_names.clear()
+        _world.pg_group_ranks.clear()
+        _world.pg_backend_config.clear()
+        _world.pg_to_tag.clear()
+        _world.tags_to_pg.clear()
+        _world.pg_coalesce_state.clear()
+        _unregister_all_process_groups()
+
+        # when process group doesn't have an explicit name (only WORLD (default)
+        # process group can have an explicit name), we use global _world.group_count
+        # to generate the name. We need to reset the counter on destruction to
+        # allow consistent value to be generated when we re-create process
+        # groups after some trainers recover from failure
+        #
+        # We only reset this when WORLD is being destroyed because if this
+        # process group is in good state, we aren't dealing with failures.
+        _world.group_count = 0
+    else:
+        _abort_backend(pg)
+        del _world.pg_map[pg]
+        del _world.pg_names[pg]
+        del _world.pg_group_ranks[pg]
+        del _world.pg_backend_config[pg]
+        if pg in _world.pg_coalesce_state.keys():
+            warnings.warn(
+                "Some coalesced collectives haven't been launched when "
+                "ProcessGroup is aborted. They will be cleaned."
+            )
+            del _world.pg_coalesce_state[pg]
+
+        tag = _world.pg_to_tag.get(pg)
+        del _world.pg_to_tag[pg]
+        if tag is not None:
+            try:
+                _world.tags_to_pg[tag].remove(pg)
+                if tag.startswith("ptd:"):
+                    _world.tags_to_pg[""].remove(pg)
+            except Exception:
+                pass
+        _unregister_process_group(pg.group_name)
+
+
 def get_rank(group: Optional[ProcessGroup] = None) -> int:
     """
     Return the rank of the current process in the provided ``group``, default otherwise.
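Usage sketch (illustrative, not part of the patch): the new `_abort_process_group` API gives applications an escape hatch when a peer rank dies and collectives can no longer complete. A minimal flow might look like the following; `LOCAL_RANK` assumes a torchrun-style launcher, and the try/except recovery path is a simplification of a real failure detector.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _abort_process_group

# Per the docstring in this patch, turn off the watchdog's automatic error
# handling so it does not race with the manual abort below.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

t = torch.ones(1, device="cuda")
try:
    dist.all_reduce(t)
    torch.cuda.synchronize()
except Exception:
    # Abort every process group (group=None aborts WORLD and all subgroups),
    # tearing down the NCCL communicators on this rank so it can
    # re-initialize or exit instead of blocking forever.
    _abort_process_group()
```

Because the binding uses `py::call_guard<py::gil_scoped_release>()`, the GIL is released while `ProcessGroupNCCL::abort()` runs, so the abort can also be issued from a side thread (for example, a health-monitor loop) while the main thread is stuck in a collective.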