[c10d][experimental] Add _abort_process_group (#132291)
Thanks @eqy for reminding me of this RFC: https://github.com/pytorch/pytorch/issues/119797

This PR is meant to:
- provide a way to abort multiple PGs without deadlocking each other, and
- provide a way to manually handle comm errors or timeouts (and potentially recover from them).

One example can be found in: https://github.com/NVIDIA/nccl/issues/1013

## How is it different from `destroy_process_group`?

`destroy_process_group` is meant for normal exit, while `_abort_process_group` is meant for bailout upon hangs or failures, similar to `ncclCommDestroy` vs `ncclCommAbort`.

## What's new in `_abort_process_group`?

It adds support for a "group abort" semantic, which can abort multiple NCCL comms concurrently, avoiding the deadlock that otherwise serialized `ncclCommAbort` executions can run into. Details are in the [RFC](https://github.com/pytorch/pytorch/issues/119797) targeting [the hang issue in the multi-comm case](https://github.com/NVIDIA/nccl/issues/1013). The "group abort" semantic was added in NCCL 2.22.

## What's next?

Ideally, the watchdog's behavior should support "group abort" too. This is hard to implement today because each PG's individual watchdog lacks a "global view". A semi-big refactor may be needed to "uplift" the watchdogs to a global level or consolidate them into one (i.e. one dog watching multiple PGs). In any case, it is worth experimenting with the "group abort" feature via a manual API first, and then extending it to the automatic mode (the watchdog).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132291
Approved by: https://github.com/eqy
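For illustration, a minimal user-side sketch (not from the PR) of the manual handle-and-abort flow this enables. It assumes a torchrun-style launch with `RANK`/`WORLD_SIZE`/`MASTER_ADDR` set; the recovery step is application-specific:

```python
import os

import torch
import torch.distributed as dist

# Keep the watchdog from aborting on its own: per the docs added in this PR,
# the manual abort API should be used with async error handling turned off.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

dist.init_process_group("nccl")  # assumes RANK/WORLD_SIZE/MASTER_ADDR are set

try:
    dist.all_reduce(torch.ones(1, device="cuda"))
    torch.cuda.synchronize()
except Exception:
    # Bail out: abort ALL process groups (default one included) within a
    # single NCCL group call so the per-communicator aborts cannot deadlock
    # each other, then run application-specific recovery.
    dist._abort_process_group()
    # ... e.g. re-create process groups and retry ...
```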
This commit is contained in:
parent: bc232e3c08
commit: fe148024fe

5 changed files with 144 additions and 13 deletions
torch/_C/_distributed_c10d.pyi:

```diff
@@ -577,6 +577,7 @@ class ProcessGroupNCCL(Backend):
     def perform_nocolor_split(self, device: torch.device) -> None: ...
     def comm_split_count(self) -> int: ...
     def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
+    def abort(self) -> None: ...
     @property
     def uid(self) -> int: ...
     @property
```
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:

```diff
@@ -1251,9 +1251,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
 }
 
 // Abort all communicators on this rank
-bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
-  // This will log counter for how long the abort actually takes.
-  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
+// Note: original name of this method is `abort`. It was renamed to
+// `abortComms` to distinguish from the `abort` method below. The `abort`
+// method calls `abortComms` but does more destruction than the latter.
+bool ProcessGroupNCCL::abortComms(std::optional<std::string> abortReason) {
   // Remove record from global ncclCommDevIdxMapMutex before aboarting,
   // so that a new cache segment would not register to already aborded
   // communicators. Note that ncclCommDevIdxMap is a global container which may
```
```diff
@@ -1272,7 +1273,11 @@ bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
   return true;
 }
 
-void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
+// Abort this backend.
+void ProcessGroupNCCL::abort() {
+  // This will log counter for how long the abort actually takes.
+  STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
+
   // Don't join threads here since the purpose of this method is to abort all
   // communicators and signal the threads to exit. Joining on the threads could
   // potentially block and hence avoid it in this method.
```
```diff
@@ -1282,8 +1287,8 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
   // lauch abort asynchrounously and wait for it to complete or timeout
   LOG(INFO) << logPrefix()
             << "Launching ProcessGroupNCCL abort asynchrounously.";
-  std::future<bool> fut = std::async(
-      std::launch::async, [this, &reason]() { return this->abort(reason); });
+  std::future<bool> fut =
+      std::async(std::launch::async, [this]() { return this->abortComms(); });
 
   waitForFutureOrTimeout(
       fut, options_->timeout, "ProcessGroup abort", true, false);
```
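The hunk above bounds a potentially hanging abort by running it asynchronously and waiting with a timeout. As a rough illustration of the same pattern (a Python sketch, not the actual C++ implementation; `run_with_timeout` is a made-up helper):

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout


def run_with_timeout(fn, timeout_s: float) -> bool:
    """Run fn on a worker thread; report failure if it doesn't finish in time."""
    pool = ThreadPoolExecutor(max_workers=1)
    fut = pool.submit(fn)
    try:
        return fut.result(timeout=timeout_s)
    except FutureTimeout:
        return False  # fn may still be running; the caller decides how to escalate
    finally:
        pool.shutdown(wait=False)  # don't block on a possibly hung worker
```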
```diff
@@ -1295,6 +1300,15 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
   monitorWakeUpCV_.notify_one();
 }
 
+// Destroy (shutdown) this backend -- normal exit.
+void ProcessGroupNCCL::shutdown() {
+  // kwen2501 (Aug 2024): moved code of `shutdown()` to `abort()` because it
+  // actually implemented an abort behavior.
+  // TODO: implementation of `shutdown` should use ncclCommDestroy() instead
+  // of ncclCommAbort(). Ideally non-blocking API mode should be used.
+  this->abort();
+}
+
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
 
```
```diff
@@ -1881,7 +1895,7 @@ void ProcessGroupNCCL::watchdogHandler() {
           work.abort();
           // PG level abort, which would abort all other communicators on this
           // rank
-          abort();
+          abortComms();
         }
 
         // Report desync state in case of timeout
```
```diff
@@ -2039,7 +2053,7 @@ void ProcessGroupNCCL::runHookLoop() {
         // already finished successfully at this point. We just need to abort
         // the process Abort all NCCL Communicators on this ProcessGroupNCCL
         // instance.
-        abort(errorStr);
+        abortComms(errorStr);
       }
     }
 
```
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp:

```diff
@@ -686,12 +686,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
 
+  // Destroy (shutdown) this backend -- normal exit.
+  void shutdown();
+
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
-  // return true if abort is successful, otherwise false
-  bool abort(const std::optional<std::string>& abortReason = std::nullopt);
-
-  void shutdown(const std::optional<std::string>& reason = std::nullopt);
+  void abort();
 
   void eagerConnectSingleDevice(at::Device device) override;
 
```
```diff
@@ -753,6 +753,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // operations, we might need to use a side thread to do it.
   bool dumpDebuggingInfo();
 
+  // Abort all communicators on this rank.
+  bool abortComms(std::optional<std::string> abortReason = std::nullopt);
+
  private:
   int globalRankStart;
   int globalRankStride;
```
torch/csrc/distributed/c10d/init.cpp:

```diff
@@ -2759,7 +2759,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
               &::c10d::ProcessGroupNCCL::setBoundDeviceId)
           .def(
               "perform_nocolor_split",
-              &::c10d::ProcessGroupNCCL::performNocolorSplit);
+              &::c10d::ProcessGroupNCCL::performNocolorSplit)
+          .def(
+              "abort",
+              &::c10d::ProcessGroupNCCL::abort,
+              py::call_guard<py::gil_scoped_release>());
 
   module.def(
       "_get_intra_node_comm_usage_counter",
```
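Note the `py::call_guard<py::gil_scoped_release>` on the new `abort` binding: the GIL is released while the (possibly blocking) abort runs, so other Python threads keep making progress. A hedged sketch of why that matters (`monitor_and_abort` is hypothetical, not part of the PR):

```python
import threading

import torch
import torch.distributed as dist


def monitor_and_abort(pg: dist.ProcessGroup, hang_detected: threading.Event) -> None:
    """Hypothetical helper: abort pg's NCCL backend once a hang is flagged."""
    hang_detected.wait()
    backend = pg._get_backend(torch.device("cuda"))
    # Because the binding releases the GIL, this call can proceed even while
    # the main thread is blocked inside a collective.
    backend.abort()
```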
torch/distributed/distributed_c10d.py:

```diff
@@ -1706,6 +1706,20 @@ def _shutdown_backend(pg):
     backend._shutdown()
 
 
+def _abort_backend(pg: ProcessGroup):
+    """
+    Abort the backend of a process group.
+    Currently, only ProcessGroupNCCL backend is supported.
+    No op for other backends.
+    """
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        backend = None
+    if isinstance(backend, ProcessGroupNCCL):
+        backend.abort()
+
+
 def _new_process_group_helper(
     group_size,
     group_rank,
```
```diff
@@ -2064,6 +2078,101 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         _unregister_process_group(pg.group_name)
 
 
+def _abort_process_group(group: Optional[ProcessGroup] = None):
+    """
+    Abort a given process group. If group.WORLD (i.e. `None`) is given, all
+    process groups including the default one will be aborted.
+
+    Args:
+        group (ProcessGroup, optional): The process group to be aborted.
+
+    .. note:: this API is experimental and currently only works with the NCCL
+        backend.
+
+    .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING`
+        turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may
+        automatically handle errors or timeouts for you including aborting the
+        ProcessGroup.
+    """
+    global _world
+
+    if group == GroupMember.NON_GROUP_MEMBER:
+        return
+
+    pg = group or GroupMember.WORLD
+
+    assert pg is not None
+    if _world.pg_map.get(pg, None) is None:
+        raise ValueError("Invalid process group specified or has been destroyed.")
+
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        backend = None
+
+    if not isinstance(backend, ProcessGroupNCCL):
+        logger.warning(
+            "`abort_process_group` currently only has implementation for ProcessGroupNCCL; "
+            "however, no NCCL backend is found. This call will be a no-op."
+        )
+        return
+
+    if group == GroupMember.WORLD:
+        # Abort all backends within a ncclGroupStart|End semantic.
+        # This ensures that different NCCL communicators' abort calls won't
+        # deadlock each other.
+        # For details, please see: https://github.com/pytorch/pytorch/issues/119797
+        backend._group_start()
+        for pg_to_abort in sorted(
+            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
+        ):
+            _abort_backend(pg_to_abort)
+        backend._group_end()
+
+        _update_default_pg(None)
+        _world.pg_map.clear()
+        _world.pg_names.clear()
+        _world.pg_group_ranks.clear()
+        _world.pg_backend_config.clear()
+        _world.pg_to_tag.clear()
+        _world.tags_to_pg.clear()
+        _world.pg_coalesce_state.clear()
+        _unregister_all_process_groups()
+
+        # when process group doesn't have an explicit name (only WORLD (default)
+        # process group can have an explicit name), we use global _world.group_count
+        # to generate the name. We need to reset the counter on destruction to
+        # allow consistent value to be generated when we re-create process
+        # groups after some trainers recover from failure
+        #
+        # We only reset this when WORLD is being destroyed because if this
+        # process group is in good state, we aren't dealing with failures.
+        _world.group_count = 0
+    else:
+        _abort_backend(pg)
+        del _world.pg_map[pg]
+        del _world.pg_names[pg]
+        del _world.pg_group_ranks[pg]
+        del _world.pg_backend_config[pg]
+        if pg in _world.pg_coalesce_state.keys():
+            warnings.warn(
+                "Some coalesced collectives haven't been launched when "
+                "ProcessGroup is aborted. They will be cleaned."
+            )
+            del _world.pg_coalesce_state[pg]
+
+        tag = _world.pg_to_tag.get(pg)
+        del _world.pg_to_tag[pg]
+        if tag is not None:
+            try:
+                _world.tags_to_pg[tag].remove(pg)
+                if tag.startswith("ptd:"):
+                    _world.tags_to_pg[""].remove(pg)
+            except Exception:
+                pass
+        _unregister_process_group(pg.group_name)
+
+
 def get_rank(group: Optional[ProcessGroup] = None) -> int:
     """
     Return the rank of the current process in the provided ``group``, default otherwise.
```
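A usage sketch for the new API (illustrative only; `sub_pg` is a made-up name and a 2-rank sub-group is assumed):

```python
import torch.distributed as dist

dist.init_process_group("nccl")
sub_pg = dist.new_group(ranks=[0, 1])

# Abort only the sub-group; bookkeeping for the other groups is untouched.
dist._abort_process_group(sub_pg)

# Or: abort every process group (default included) in one "group abort",
# then clear all process-group state so the job can re-initialize.
dist._abort_process_group()
```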