[c10d][experimental] Add _abort_process_group (#132291)

Thanks @eqy for reminding me of this RFC: https://github.com/pytorch/pytorch/issues/119797

This PR is meant to:
- provide a way to abort multiple PGs without deadlocking each other.
- provide a possibility to manually handle comm errors or timeouts (and potentially recovery of such).
One can find an example from: https://github.com/NVIDIA/nccl/issues/1013

## How is it different from `destroy_process_group`?
`destroy_process_group` is meant for normal exit, while `_abort_process_group` is meant for bailout upon hangs or failures. Similar to `ncclCommDestroy` vs `ncclCommAbort`.
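The split can be sketched from the caller's side. Below is a hedged usage sketch (not code from this PR): a hypothetical step wrapper that destroys process groups on a clean exit and aborts them on a communication failure. `run_step` is an assumed user function; `_abort_process_group` is the experimental API this PR adds.

```python
# Hypothetical error-handling wrapper around one training step.
# destroy_process_group ~ ncclCommDestroy (graceful), while
# _abort_process_group ~ ncclCommAbort (bailout on hang/failure).

def finish_or_bailout(run_step, pg=None):
    """Run one step; destroy PGs on success, abort them on comm failure."""
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _abort_process_group

    try:
        run_step()
        dist.destroy_process_group(pg)  # normal exit: graceful teardown
        return True
    except RuntimeError:
        # A hang or comm error: destroy could block on outstanding work,
        # so bail out with abort instead. pg=None aborts all PGs.
        _abort_process_group(pg)
        return False
```

Note this sketch assumes the collective error surfaces as a `RuntimeError` on the calling thread, which requires the watchdog's automatic handling to be disabled (see the notes in the new docstring below).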

## What's new in `_abort_process_group`?
It adds support for a "group abort" semantic. The "group abort" semantic can abort multiple NCCL comms concurrently, avoiding the deadlock that can occur when `ncclCommAbort` calls are executed serially. Details are in the [RFC](https://github.com/pytorch/pytorch/issues/119797) targeting [the hang issue in the multi-comm case](https://github.com/NVIDIA/nccl/issues/1013). The "group abort" semantic was added in NCCL 2.22.
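The batching idea can be illustrated with a minimal, self-contained Python sketch (not the PyTorch implementation, which wraps `ncclGroupStart`/`ncclGroupEnd` in C++): abort calls issued inside a group window are recorded and dispatched together at group end, so no single abort blocks while holding up the others. All names here (`FakeComm`, `CommGroup`) are hypothetical.

```python
class FakeComm:
    """Stand-in for a NCCL communicator."""
    def __init__(self, name):
        self.name = name
        self.aborted = False

    def abort(self):
        self.aborted = True


class CommGroup:
    """Batches abort calls, mimicking ncclGroupStart/ncclGroupEnd."""
    def __init__(self):
        self.pending = []

    def __enter__(self):  # ~ ncclGroupStart()
        return self

    def abort(self, comm):
        # Inside the group window, only record the request.
        self.pending.append(comm)

    def __exit__(self, *exc):  # ~ ncclGroupEnd()
        # On group end, issue all recorded aborts in one batch.
        for comm in self.pending:
            comm.abort()
        return False


comms = [FakeComm(f"pg{i}") for i in range(3)]
with CommGroup() as g:
    for c in comms:
        g.abort(c)
assert all(c.aborted for c in comms)  # every comm aborted in one batch
```

The real implementation (see the `_abort_process_group` hunk below) does the same thing with `backend._group_start()` / `backend._group_end()` around the per-PG aborts.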

## What's next?
Ideally, the watchdog's behavior should support "group abort" too. But this is hard to implement today because each PG's individual watchdog lacks a "global view". A semi-big refactor may be needed to "uplift" the watchdogs to a global level or consolidate them into one (i.e. one dog watching multiple PGs).

In any case, it may not be a bad idea to experiment with the "group abort" feature via a manual API first and then extend it to the automatic mode (watchdog).
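For intuition, a consolidated watchdog might look like the following speculative sketch: a single poll loop with a global view of every PG that, on the first detected timeout, group-aborts all of them instead of letting each PG's watchdog abort serially. Everything here (`FakePG`, `watchdog_pass`) is hypothetical and not PyTorch code.

```python
class FakePG:
    """Stand-in for a process group as seen by a global watchdog."""
    def __init__(self, name, timed_out=False):
        self.name = name
        self.timed_out = timed_out
        self.aborted = False

    def abort(self):
        self.aborted = True


def watchdog_pass(pgs):
    """One poll cycle: if any PG timed out, group-abort all of them."""
    if any(pg.timed_out for pg in pgs):
        # stand-in for wrapping the aborts in ncclGroupStart()/ncclGroupEnd()
        for pg in pgs:
            pg.abort()
        return True
    return False


pgs = [FakePG("world"), FakePG("tp", timed_out=True), FakePG("pp")]
assert watchdog_pass(pgs)
assert all(pg.aborted for pg in pgs)  # one timeout tears down all PGs together
```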

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132291
Approved by: https://github.com/eqy
This commit is contained in:
Ke Wen 2024-10-10 19:07:32 -07:00 committed by PyTorch MergeBot
parent bc232e3c08
commit fe148024fe
5 changed files with 144 additions and 13 deletions


@@ -577,6 +577,7 @@ class ProcessGroupNCCL(Backend):
def perform_nocolor_split(self, device: torch.device) -> None: ...
def comm_split_count(self) -> int: ...
def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
def abort(self) -> None: ...
@property
def uid(self) -> int: ...
@property


@@ -1251,9 +1251,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
}
// Abort all communicators on this rank
bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
// This will log a counter for how long the abort actually takes.
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
// Note: original name of this method is `abort`. It was renamed to
// `abortComms` to distinguish from the `abort` method below. The `abort`
// method calls `abortComms` but does more destruction than the latter.
bool ProcessGroupNCCL::abortComms(std::optional<std::string> abortReason) {
// Remove record from global ncclCommDevIdxMapMutex before aborting,
// so that a new cache segment would not register to already aborted
// communicators. Note that ncclCommDevIdxMap is a global container which may
@@ -1272,7 +1273,11 @@ bool ProcessGroupNCCL::abort(const std::optional<std::string>& abortReason) {
return true;
}
void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
// Abort this backend.
void ProcessGroupNCCL::abort() {
// This will log a counter for how long the abort actually takes.
STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort);
// Don't join threads here since the purpose of this method is to abort all
// communicators and signal the threads to exit. Joining the threads could
// potentially block, hence we avoid it in this method.
@@ -1282,8 +1287,8 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
// launch abort asynchronously and wait for it to complete or timeout
LOG(INFO) << logPrefix()
<< "Launching ProcessGroupNCCL abort asynchronously.";
std::future<bool> fut = std::async(
std::launch::async, [this, &reason]() { return this->abort(reason); });
std::future<bool> fut =
std::async(std::launch::async, [this]() { return this->abortComms(); });
waitForFutureOrTimeout(
fut, options_->timeout, "ProcessGroup abort", true, false);
@@ -1295,6 +1300,15 @@ void ProcessGroupNCCL::shutdown(const std::optional<std::string>& reason) {
monitorWakeUpCV_.notify_one();
}
// Destroy (shutdown) this backend -- normal exit.
void ProcessGroupNCCL::shutdown() {
// kwen2501 (Aug 2024): moved code of `shutdown()` to `abort()` because it
// actually implemented an abort behavior.
// TODO: implementation of `shutdown` should use ncclCommDestroy() instead
// of ncclCommAbort(). Ideally non-blocking API mode should be used.
this->abort();
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
@@ -1881,7 +1895,7 @@ void ProcessGroupNCCL::watchdogHandler() {
work.abort();
// PG level abort, which would abort all other communicators on this
// rank
abort();
abortComms();
}
// Report desync state in case of timeout
@@ -2039,7 +2053,7 @@ void ProcessGroupNCCL::runHookLoop() {
// already finished successfully at this point. We just need to abort
// all NCCL communicators on this ProcessGroupNCCL instance.
abort(errorStr);
abortComms(errorStr);
}
}


@@ -686,12 +686,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
// Destroy (shutdown) this backend -- normal exit.
void shutdown();
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
// instead of relying on ProcessGroupNCCL destructor.
// return true if abort is successful, otherwise false
bool abort(const std::optional<std::string>& abortReason = std::nullopt);
void shutdown(const std::optional<std::string>& reason = std::nullopt);
void abort();
void eagerConnectSingleDevice(at::Device device) override;
@@ -753,6 +753,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// operations, we might need to use a side thread to do it.
bool dumpDebuggingInfo();
// Abort all communicators on this rank.
bool abortComms(std::optional<std::string> abortReason = std::nullopt);
private:
int globalRankStart;
int globalRankStride;


@@ -2759,7 +2759,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
&::c10d::ProcessGroupNCCL::setBoundDeviceId)
.def(
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit);
&::c10d::ProcessGroupNCCL::performNocolorSplit)
.def(
"abort",
&::c10d::ProcessGroupNCCL::abort,
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_intra_node_comm_usage_counter",


@@ -1706,6 +1706,20 @@ def _shutdown_backend(pg):
backend._shutdown()
def _abort_backend(pg: ProcessGroup):
"""
Abort the backend of a process group.
Currently, only the ProcessGroupNCCL backend is supported.
No-op for other backends.
"""
try:
backend = pg._get_backend(torch.device("cuda"))
except RuntimeError:
backend = None
if isinstance(backend, ProcessGroupNCCL):
backend.abort()
def _new_process_group_helper(
group_size,
group_rank,
@@ -2064,6 +2078,101 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_unregister_process_group(pg.group_name)
def _abort_process_group(group: Optional[ProcessGroup] = None):
"""
Abort a given process group. If group.WORLD (i.e. `None`) is given, all
process groups including the default one will be aborted.
Args:
group (ProcessGroup, optional): The process group to be aborted.
.. note:: this API is experimental and currently only works with the NCCL
backend.
.. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING`
turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may
automatically handle errors or timeouts for you including aborting the
ProcessGroup.
"""
global _world
if group == GroupMember.NON_GROUP_MEMBER:
return
pg = group or GroupMember.WORLD
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise ValueError("Invalid process group specified or has been destroyed.")
try:
backend = pg._get_backend(torch.device("cuda"))
except RuntimeError:
backend = None
if not isinstance(backend, ProcessGroupNCCL):
logger.warning(
"`_abort_process_group` is currently only implemented for ProcessGroupNCCL; "
"however, no NCCL backend was found. This call will be a no-op."
)
return
if group == GroupMember.WORLD:
# Abort all backends within a ncclGroupStart|End semantic.
# This ensures that different NCCL communicators' abort calls won't
# deadlock each other.
# For details, please see: https://github.com/pytorch/pytorch/issues/119797
backend._group_start()
for pg_to_abort in sorted(
_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
):
_abort_backend(pg_to_abort)
backend._group_end()
_update_default_pg(None)
_world.pg_map.clear()
_world.pg_names.clear()
_world.pg_group_ranks.clear()
_world.pg_backend_config.clear()
_world.pg_to_tag.clear()
_world.tags_to_pg.clear()
_world.pg_coalesce_state.clear()
_unregister_all_process_groups()
# when process group doesn't have an explicit name (only WORLD (default)
# process group can have an explicit name), we use global _world.group_count
# to generate the name. We need to reset the counter on destruction to
# allow consistent value to be generated when we re-create process
# groups after some trainers recover from failure
#
# We only reset this when WORLD is being destroyed because if this
# process group is in good state, we aren't dealing with failures.
_world.group_count = 0
else:
_abort_backend(pg)
del _world.pg_map[pg]
del _world.pg_names[pg]
del _world.pg_group_ranks[pg]
del _world.pg_backend_config[pg]
if pg in _world.pg_coalesce_state.keys():
warnings.warn(
"Some coalesced collectives haven't been launched when "
"ProcessGroup is aborted. They will be cleaned."
)
del _world.pg_coalesce_state[pg]
tag = _world.pg_to_tag.get(pg)
del _world.pg_to_tag[pg]
if tag is not None:
try:
_world.tags_to_pg[tag].remove(pg)
if tag.startswith("ptd:"):
_world.tags_to_pg[""].remove(pg)
except Exception:
pass
_unregister_process_group(pg.group_name)
def get_rank(group: Optional[ProcessGroup] = None) -> int:
"""
Return the rank of the current process in the provided ``group``, default otherwise.