[DCP] Remove all-gather of state dict keys (#145998)
The original `_all_gather_keys` call was a safety check, but it can become costly as jobs scale and it blocks the CPU. Instead, we make it clear in the documentation that the `state_dict` passed to the `load` API must have the same set of keys on every rank; otherwise the API may hang. In addition, we move the check into a utility function, `utils._assert_same_keys`. Users unsure whether their state dict keys match across ranks can optionally call this API to check. Resolves #145965 (as a workaround). Pull Request resolved: https://github.com/pytorch/pytorch/pull/145998 Approved by: https://github.com/mhorowitz, https://github.com/fegin
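To make the new contract concrete, here is a minimal sketch of the intended call pattern (the model, checkpoint path, and process-group setup are illustrative, not part of this PR; it assumes `torch.distributed` is already initialized on every rank):

    import torch.nn as nn
    import torch.distributed.checkpoint as dcp

    model = nn.Linear(4, 4)  # same module definition on every rank
    state_dict = model.state_dict()

    # Optional sanity check: a collective that raises AssertionError on any
    # rank whose state dict is missing keys present on another rank. It costs
    # one round of communication and blocks the CPU until all ranks join.
    dcp.utils._assert_same_keys(state_dict)

    # With matching keys across ranks the load is well defined; with
    # mismatched keys it may hang or error.
    dcp.load(state_dict, checkpoint_id="/tmp/checkpoint")  # illustrative path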
parent 7f796eb8b7
commit 762a05b3b3
3 changed files with 57 additions and 16 deletions
@@ -2,6 +2,7 @@
 import os
+from unittest.mock import patch
 
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from torch.distributed._tensor.device_mesh import init_device_mesh

@@ -62,6 +63,24 @@ class TestSaveAndLoadAPI(DTensorTestBase):
         with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
             dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_assert_same_keys(self):
+        """Test the `_assert_same_keys` function."""
+        model = MyTestModule()
+        state_dict = model.state_dict()
+        # Check across ranks; expect success
+        dcp.utils._assert_same_keys(state_dict)
+
+        # Introduce a difference; expect failure
+        if self.rank == 0:
+            state_dict["abc"] = torch.rand(1)
+        else:
+            state_dict["def"] = torch.rand(1)
+
+        with self.assertRaises(AssertionError):
+            dcp.utils._assert_same_keys(state_dict)
+
 
 if __name__ == "__main__":
     run_tests()
torch/distributed/checkpoint/state_dict_loader.py

@@ -15,7 +15,7 @@ from ._storage_utils import _storage_setup
 from .default_planner import DefaultLoadPlanner
 from .planner import LoadPlan, LoadPlanner
 from .storage import StorageReader
-from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile
+from .utils import _api_bc_check, _DistWrapper, _profile
 
 
 __all__ = ["load_state_dict", "load"]

@@ -60,7 +60,12 @@ def load(
     no_dist: bool = False,
 ) -> None:
     """
-    Load a distributed ``state_dict`` in SPMD style.
+    Load a checkpoint into a distributed state dict in SPMD style.
+
+    Each rank must have the same keys in its ``state_dict`` provided to this
+    API. Mismatched keys may result in hangs or errors. If unsure, you can use
+    the ``utils._assert_same_keys`` API to check (though it may incur
+    communication costs).
 
     Each rank will try to read the least amount of data necessary
     to fulfill the requested `state_dict`. When loading :class:`ShardedTensor`

@@ -93,7 +98,7 @@ def load(
     Rank 0 is assumed to be the coordinator rank.
 
     Args:
-        state_dict (Dict[str, Any]): The state_dict to save.
+        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
         checkpoint_id (Union[str, os.PathLike, None]):
             The ID of this checkpoint instance. The meaning of the checkpoint_id
             depends on the storage. It can be a path to a folder or to a file.

@@ -152,15 +157,11 @@ def load(
         StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
     )
 
-    if no_dist:
-        keys = list(state_dict.keys())
-    else:
-        keys = _all_gather_keys(state_dict, process_group)
-        if keys != sorted(state_dict.keys()):
-            warnings.warn(
-                "Detected mismatched keys in state dict after all gather!"
-                " This behavior is unsupported and may cause errors."
-            )
+    # All ranks must have the same keys in their `state_dict` provided to
+    # this API. See documentation for more details.
+    # Here we simply sort the keys to ensure that all ranks load values in
+    # the same order.
+    keys = sorted(state_dict.keys())
 
     statetful_sd = {}
     for key in keys:
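The replacement works because sorting needs no communication: if every rank holds the same key set, `sorted()` alone gives each rank an identical iteration order. A toy, single-process illustration of that invariant (the dict contents are made up for the example):

    # Two "ranks" with the same keys inserted in different orders.
    rank0_sd = {"layer1.weight": 0, "layer0.bias": 1}
    rank1_sd = {"layer0.bias": 1, "layer1.weight": 0}

    # Sorting yields the same deterministic order on both, with no collective,
    # which is all the load path needs to line up its reads across ranks.
    assert sorted(rank0_sd.keys()) == sorted(rank1_sd.keys())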
torch/distributed/checkpoint/utils.py

@@ -41,14 +41,35 @@ def _get_failure_dict(
 
 
 def _all_gather_keys(
-    local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
-) -> list[Any]:
+    local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None
+) -> set[str]:
     """Gathers all keys, and returns them sorted."""
     keys = list(local_dict.keys())
-    gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
+    gathered_keys: list[list[str]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
 
     dist.all_gather_object(gathered_keys, keys, group=group)
-    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
+    return set(itertools.chain.from_iterable(gathered_keys))
+
+
+def _assert_same_keys(
+    state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None
+) -> None:
+    """
+    Asserts that all ranks have the same keys in their state dict.
+    This is a collective call which requires all ranks in ``process_group`` to
+    join. It will also induce cross-rank communication and block the CPU.
+    """
+
+    if dist.get_world_size(process_group) == 1:
+        return
+
+    all_keys = _all_gather_keys(state_dict, process_group)
+    my_keys = set(state_dict.keys())
+    diff = all_keys - my_keys
+    if len(diff) > 0:
+        raise AssertionError(
+            f"Key(s) present in other ranks but not this one, difference: {diff}"
+        )
 
 
 class _DistWrapper:
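Note the asymmetry in the check: each rank compares the gathered union against its own keys, so a rank raises only for keys it is missing. A rank holding an extra key passes locally, but every other rank is then missing that key and raises, so any mismatch still surfaces on at least one rank. A plain-Python sketch of that set logic (the key names are made up for the example):

    rank0_keys = {"weight", "bias", "extra"}  # rank 0 has one extra key
    rank1_keys = {"weight", "bias"}

    all_keys = rank0_keys | rank1_keys  # what _all_gather_keys returns

    assert all_keys - rank0_keys == set()      # rank 0 passes the check
    assert all_keys - rank1_keys == {"extra"}  # rank 1 raises AssertionError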