[DCP] Remove all-gather of state dict keys (#145998)

The original `_all_gather_keys` call was a safety check, but it can become costly at scale, and it blocks the CPU.

Instead, we make it clear in the documentation that the `state_dict` passed to the `load` API must have the same set of keys on every rank; otherwise, the API may hang.

In addition, we move the check into a utility function, `utils._assert_same_keys`. Users uncertain whether their state dict keys match across ranks can optionally call this API to check.

Resolves #145965 (as a workaround).
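
For users who want to keep the old safety net, a minimal usage sketch (assumes a process group is already initialized; the model and checkpoint path are placeholders):

import torch.nn as nn
import torch.distributed.checkpoint as dcp

model = nn.Linear(4, 4)  # placeholder model
state_dict = model.state_dict()

# Optional sanity check: a collective call that gathers keys across ranks and
# raises AssertionError if this rank is missing any key. It incurs
# communication cost and blocks the CPU, so call it only when unsure.
dcp.utils._assert_same_keys(state_dict)

# With keys verified identical on every rank, load() will not hang on
# account of mismatched keys.
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")  # placeholder path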

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145998
Approved by: https://github.com/mhorowitz, https://github.com/fegin
commit 762a05b3b3 (parent 7f796eb8b7)
Ke Wen, 2025-02-03 15:59:35 -08:00; committed by PyTorch MergeBot
3 changed files with 57 additions and 16 deletions


@@ -2,6 +2,7 @@
 import os
 from unittest.mock import patch

 import torch
+import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from torch.distributed._tensor.device_mesh import init_device_mesh
@@ -62,6 +63,24 @@ class TestSaveAndLoadAPI(DTensorTestBase):
         with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
             dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")

+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_assert_same_keys(self):
+        """Test the `_assert_same_keys` function."""
+        model = MyTestModule()
+        state_dict = model.state_dict()
+
+        # Check across ranks; keys match, so no error is expected
+        dcp.utils._assert_same_keys(state_dict)
+
+        # Introduce a per-rank difference; an AssertionError is expected
+        if self.rank == 0:
+            state_dict["abc"] = torch.rand(1)
+        else:
+            state_dict["def"] = torch.rand(1)
+        with self.assertRaises(AssertionError):
+            dcp.utils._assert_same_keys(state_dict)
+

 if __name__ == "__main__":
     run_tests()


@@ -15,7 +15,7 @@
 from ._storage_utils import _storage_setup
 from .default_planner import DefaultLoadPlanner
 from .planner import LoadPlan, LoadPlanner
 from .storage import StorageReader
-from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile
+from .utils import _api_bc_check, _DistWrapper, _profile

 __all__ = ["load_state_dict", "load"]
@@ -60,7 +60,12 @@ def load(
     no_dist: bool = False,
 ) -> None:
     """
-    Load a distributed ``state_dict`` in SPMD style.
+    Load a checkpoint into a distributed state dict in SPMD style.
+
+    Each rank must have the same keys in their ``state_dict`` provided to this
+    API. Mismatched keys may result in hangs or errors. If unsure, you can use
+    the ``utils._assert_same_keys`` API to check (but it may incur communication
+    costs).

     Each rank will try to read the least amount of data necessary
     to fullfill the requested `state_dict`. When loading :class:`ShardedTensor`
@@ -93,7 +98,7 @@
     Rank 0 is assumed to be the coordinator rank.

     Args:
-        state_dict (Dict[str, Any]): The state_dict to save.
+        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
         checkpoint_id (Union[str, os.PathLike, None]):
             The ID of this checkpoint instance. The meaning of the checkpoint_id
             depends on the storage. It can be a path to a folder or to a file.
@@ -152,15 +157,11 @@
         StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
     )

-    if no_dist:
-        keys = list(state_dict.keys())
-    else:
-        keys = _all_gather_keys(state_dict, process_group)
-        if keys != sorted(state_dict.keys()):
-            warnings.warn(
-                "Detected mismatched keys in state dict after all gather!"
-                " This behavior is unsupported and may cause errors."
-            )
+    # All ranks must have the same keys in their `state_dict` provided to
+    # this API. See documentation for more details.
+    # Here we simply sort the keys to ensure that all ranks load values in
+    # the same order.
+    keys = sorted(state_dict.keys())

     statetful_sd = {}
     for key in keys:

@@ -41,14 +41,35 @@ def _get_failure_dict(

 def _all_gather_keys(
-    local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
-) -> list[Any]:
+    local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None
+) -> set[str]:
     """Gathers all keys, and returns them sorted."""
     keys = list(local_dict.keys())
-    gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
+    gathered_keys: list[list[str]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
     dist.all_gather_object(gathered_keys, keys, group=group)
-    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
+    return set(itertools.chain.from_iterable(gathered_keys))
+
+
+def _assert_same_keys(
+    state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None
+) -> None:
+    """
+    Asserts that all ranks have the same keys in their state dict.
+
+    This is a collective call which requires all ranks in ``process_group`` to
+    join. It will also induce cross-rank communication and block the CPU.
+    """
+    if dist.get_world_size(process_group) == 1:
+        return
+
+    all_keys = _all_gather_keys(state_dict, process_group)
+    my_keys = set(state_dict.keys())
+    diff = all_keys - my_keys
+    if len(diff) > 0:
+        raise AssertionError(
+            f"Key(s) present in other ranks but not this one, difference: {diff}"
+        )


 class _DistWrapper:
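
To see how the assertion above fires, the union/difference logic can be simulated in a single process (key names are illustrative):

# Two simulated ranks whose state dicts diverge by one key each,
# mirroring the test added in this PR.
rank0_keys = {"weight", "bias", "abc"}  # rank 0 added "abc"
rank1_keys = {"weight", "bias", "def"}  # rank 1 added "def"

# _all_gather_keys returns the union of every rank's keys.
all_keys = rank0_keys | rank1_keys

# Each rank checks for keys it is missing; both find one, so both raise.
assert all_keys - rank0_keys == {"def"}
assert all_keys - rank1_keys == {"abc"}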