From 762a05b3b3ef40bcbb2ae4edb1945ffb399fb333 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 3 Feb 2025 15:59:35 -0800
Subject: [PATCH] [DCP] Remove all-gather of state dict keys (#145998)

The original `_all_gather_keys` call was a safety check, but it can become
costly at scale, and it blocks the CPU. Instead, we make it clear in the
documentation that the `state_dict` passed to the `load` API should have
the same set of keys on all ranks; otherwise the API may hang.

In addition, we move the check into a utility function,
`utils._assert_same_keys`. Users uncertain about state dict uniformity can
optionally call this API to check (a usage sketch follows the diff below).

Resolves #145965 (as a workaround).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145998
Approved by: https://github.com/mhorowitz, https://github.com/fegin
---
 .../checkpoint/test_save_load_api.py          | 19 ++++++++++++
 .../checkpoint/state_dict_loader.py           | 25 ++++++++--------
 torch/distributed/checkpoint/utils.py         | 31 ++++++++++++++++-----
 3 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/test/distributed/checkpoint/test_save_load_api.py b/test/distributed/checkpoint/test_save_load_api.py
index 862f59f00da..e50a6b07cde 100644
--- a/test/distributed/checkpoint/test_save_load_api.py
+++ b/test/distributed/checkpoint/test_save_load_api.py
@@ -2,6 +2,7 @@
 import os
 from unittest.mock import patch
 
+import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from torch.distributed._tensor.device_mesh import init_device_mesh
@@ -62,6 +63,24 @@ class TestSaveAndLoadAPI(DTensorTestBase):
         with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
             dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_assert_same_keys(self):
+        """Test the `_assert_same_keys` function."""
+        model = MyTestModule()
+        state_dict = model.state_dict()
+        # Same keys on all ranks; expect no error
+        dcp.utils._assert_same_keys(state_dict)
+
+        # Introduce a mismatch; expect an AssertionError
+        if self.rank == 0:
+            state_dict["abc"] = torch.rand(1)
+        else:
+            state_dict["def"] = torch.rand(1)
+
+        with self.assertRaises(AssertionError):
+            dcp.utils._assert_same_keys(state_dict)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index 4e732b83667..5221f5ff8ed 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -15,7 +15,7 @@ from ._storage_utils import _storage_setup
 from .default_planner import DefaultLoadPlanner
 from .planner import LoadPlan, LoadPlanner
 from .storage import StorageReader
-from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile
+from .utils import _api_bc_check, _DistWrapper, _profile
 
 
 __all__ = ["load_state_dict", "load"]
@@ -60,7 +60,12 @@ def load(
     no_dist: bool = False,
 ) -> None:
     """
-    Load a distributed ``state_dict`` in SPMD style.
+    Load a checkpoint into a distributed state dict in SPMD style.
+
+    Each rank must have the same keys in its ``state_dict`` provided to
+    this API. Mismatched keys may result in hangs or errors. If unsure,
+    you can use the ``utils._assert_same_keys`` API to check (but this
+    may incur communication costs).
 
     Each rank will try to read the least amount of data necessary
     to fullfill the requested `state_dict`. When loading :class:`ShardedTensor`
@@ -93,7 +98,7 @@ def load(
         Rank 0 is assumed to be the coordinator rank.
 
     Args:
-        state_dict (Dict[str, Any]): The state_dict to save.
+        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
         checkpoint_id (Union[str, os.PathLike, None]):
             The ID of this checkpoint instance. The meaning of the checkpoint_id
             depends on the storage. It can be a path to a folder or to a file.
@@ -152,15 +157,11 @@ def load(
         StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
     )
 
-    if no_dist:
-        keys = list(state_dict.keys())
-    else:
-        keys = _all_gather_keys(state_dict, process_group)
-        if keys != sorted(state_dict.keys()):
-            warnings.warn(
-                "Detected mismatched keys in state dict after all gather!"
-                " This behavior is unsupported and may cause errors may cause errors."
-            )
+    # All ranks must have the same keys in the `state_dict` provided to
+    # this API. See the documentation for more details.
+    # Here we simply sort the keys to ensure that all ranks load values in
+    # the same order.
+    keys = sorted(state_dict.keys())
 
     statetful_sd = {}
     for key in keys:
diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py
index 3513fc678b0..cd483f8a779 100644
--- a/torch/distributed/checkpoint/utils.py
+++ b/torch/distributed/checkpoint/utils.py
@@ -41,14 +41,35 @@ def _get_failure_dict(
 
 
 def _all_gather_keys(
-    local_dict: dict[Any, Any], group: Optional[dist.ProcessGroup] = None
-) -> list[Any]:
-    """Gathers all keys, and returns them sorted."""
+    local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None
+) -> set[str]:
+    """Gathers keys from all ranks and returns them as a set."""
     keys = list(local_dict.keys())
-    gathered_keys: list[list[Any]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
+    gathered_keys: list[list[str]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
     dist.all_gather_object(gathered_keys, keys, group=group)
 
-    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
+    return set(itertools.chain.from_iterable(gathered_keys))
+
+
+def _assert_same_keys(
+    state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None
+) -> None:
+    """
+    Asserts that all ranks have the same keys in their state dict.
+    This is a collective call which requires all ranks in ``process_group``
+    to join. It will also induce cross-rank communication and block the CPU.
+    """
+
+    if dist.get_world_size(process_group) == 1:
+        return
+
+    all_keys = _all_gather_keys(state_dict, process_group)
+    my_keys = set(state_dict.keys())
+    diff = all_keys - my_keys
+    if len(diff) > 0:
+        raise AssertionError(
+            f"Key(s) present in other ranks but not this one, difference: {diff}"
+        )
 
 
 class _DistWrapper:
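
Usage sketch (editorial illustration, not part of the patch): how a caller might combine the new optional key check with `dcp.load`. The "gloo" backend, the `nn.Linear` model, and the checkpoint path below are placeholder assumptions, and `_assert_same_keys` is a private helper that may change without notice.

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn

# Placeholder setup: assumes launch via torchrun so RANK/WORLD_SIZE are set.
dist.init_process_group(backend="gloo")

model = nn.Linear(8, 8)  # placeholder model, identical on every rank
state_dict = model.state_dict()

# Optional collective check: raises AssertionError on any rank whose
# state_dict is missing keys that other ranks have. It all-gathers keys
# and blocks the CPU, which is why it was moved out of load() itself.
dcp.utils._assert_same_keys(state_dict)

# With the same keys on every rank, load() no longer all-gathers keys;
# each rank sorts its local keys so all ranks load in the same order.
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")  # placeholder checkpoint path

dist.destroy_process_group()
```

Since `load` no longer all-gathers keys internally, a key mismatch now surfaces as a hang or error rather than a warning; running the check once during bring-up and then removing it keeps the steady-state load path free of extra communication.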