diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst
index e948b66e1ec..573faa429b7 100644
--- a/docs/source/distributed.checkpoint.rst
+++ b/docs/source/distributed.checkpoint.rst
@@ -36,6 +36,9 @@ The following module is also useful for additional customization of the staging
 .. autoclass:: torch.distributed.checkpoint.staging.AsyncStager
    :members:
 
+.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager
+   :members:
+
 In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading
 
 .. automodule:: torch.distributed.checkpoint.stateful
diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
index 8d4733b8276..f24bb131667 100644
--- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
+++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py
@@ -211,10 +211,18 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    def test_e2e_async(self):
-        self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)
+    @parametrize("cache_staged_state_dict", [False, True])
+    def test_e2e_async_cached(self, cache_staged_state_dict):
+        self._run_e2e_test(
+            compile=False,
+            model_type=ModelType.FSDP,
+            async_op=True,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
 
-    def _run_e2e_test(self, compile, model_type, async_op=False):
+    def _run_e2e_test(
+        self, compile, model_type, async_op=False, cache_staged_state_dict=False
+    ):
         model, optim = self._create_model(compile, ModelType.NONE)
         _train(model, optim, train_steps=2)
 
@@ -230,7 +238,10 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
         }
 
         if async_op:
-            f = saver.async_save(sd, checkpoint_id=self.temp_dir)
+            writer = DCP.FileSystemWriter(
+                self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
+            )
+            f = saver.async_save(sd, storage_writer=writer)
             t = time.monotonic()
             while not f.done():
                 time.sleep(1)
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index 9b6345862ce..3672e2401bb 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -11,6 +11,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     cast,
     Dict,
@@ -28,6 +29,8 @@ import torch
 from torch import Tensor
 from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint.staging import BlockingAsyncStager
+
 from torch.futures import Future
 
 from .metadata import Metadata, MetadataIndex
@@ -393,7 +396,7 @@ class FileSystem(FileSystemBase):
         return False
 
 
-class FileSystemWriter(StorageWriter):
+class _FileSystemWriter(StorageWriter):
     """
     Basic implementation of StorageWriter using file IO.
 
@@ -414,6 +417,8 @@
         single_file_per_rank: bool = True,
         sync_files: bool = True,
         thread_count: int = 1,
         per_thread_copy_ahead: int = 10_000_000,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Initialize the writer pointing to `path`.
@@ -631,3 +636,51 @@ class FileSystemReader(StorageReader):
     @classmethod
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
+    """
+    Basic implementation of StorageWriter using file IO.
+
+    This implementation makes the following assumptions and simplifications:
+
+    * The checkpoint path is an empty or non-existing directory.
+    * File creation is atomic.
+
+    The checkpoint consists of one file per write request, plus
+    a `.metadata` file with the serialized metadata.
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+        cache_staged_state_dict: bool = False,
+    ) -> None:
+        """
+        Initialize the writer pointing to `path`.
+
+        Args:
+            path: directory where the checkpoint will be written to.
+            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
+            sync_files: force files to be synced to permanent storage. Defaults to True.
+            thread_count: Number of IO threads to use to write. Defaults to 1.
+            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
+            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
+                at the cost of increased memory usage. Additionally, if this parameter is set to True, the writer is
+                expected to be kept alive and reused for multiple dcp.async_save calls. Defaults to False.
+
+        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
+        """
+        super().__init__(
+            path=path,
+            single_file_per_rank=single_file_per_rank,
+            sync_files=sync_files,
+            thread_count=thread_count,
+            per_thread_copy_ahead=per_thread_copy_ahead,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py
index 3dd294eb0a5..f4ce2673dfc 100644
--- a/torch/distributed/checkpoint/staging.py
+++ b/torch/distributed/checkpoint/staging.py
@@ -1,10 +1,16 @@
-from typing import runtime_checkable
+from typing import Optional, runtime_checkable
 from typing_extensions import Protocol
 
+from torch.distributed._state_dict_utils import (
+    _copy_state_dict,
+    _create_cpu_state_dict,
+    _offload_state_dict_to_cpu,
+)
+
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
 
-__all__ = ["AsyncStager"]
+__all__ = ["AsyncStager", "BlockingAsyncStager"]
 
 
 @runtime_checkable
@@ -61,3 +67,54 @@ class AsyncStager(Protocol):
         is complete and it is safe to begin modifying the original `state_dict`
         """
         pass
+
+
+class BlockingAsyncStager(AsyncStager):
+    """
+    An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete.
+    This implementation also provides an option to optimize staging latency using pinned memory.
+
+    N.B. synchronize_staging is a no-op in this case.
+    """
+
+    # staging is blocking here, so there is nothing to synchronize after execute
+    _synchronize_after_execute: bool = False
+
+    def __init__(
+        self,
+        cache_staged_state_dict: bool = False,
+        type_check: bool = False,
+    ):
+        """
+        Initializes the BlockingAsyncStager.
+
+        Args:
+            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
+                at the cost of increased memory usage. Additionally, if this parameter is set to True, the stager is
+                expected to be kept alive and reused for multiple dcp.async_save calls. Defaults to False.
+            type_check: Whether to perform a type check during cpu_offload. Defaults to False.
+        """
+        self.cache_staged_state_dict = cache_staged_state_dict
+        self.type_check = type_check
+        self.state_dict_cache: Optional[STATE_DICT_TYPE] = None
+
+    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+        """
+        Returns a copy of `state_dict` on the CPU.
+        """
+        if not self.cache_staged_state_dict:
+            return _offload_state_dict_to_cpu(state_dict, type_check=self.type_check)
+
+        # Allocate the pinned-memory cache on first use, then copy into it on
+        # every subsequent call.
+        if self.state_dict_cache is None:
+            self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True)
+        return _copy_state_dict(state_dict, self.state_dict_cache)
+
+    def synchronize_staging(self) -> None:
+        """
+        No-op, since staging is blocking.
+        """
+        pass
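
---

For reference, a minimal usage sketch of the new `cache_staged_state_dict` option (not part of the diff; `model`, `optim`, `train_step`, and the checkpoint path are assumed stand-ins for an application's own training loop):

```python
# Sketch only: assumes torch.distributed is already initialized and that
# `model` and `optim` exist; `train_step` is a hypothetical training helper.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict_saver import async_save

# With cache_staged_state_dict=True, the writer allocates a pinned-memory CPU
# copy of the state_dict on the first save and reuses it on later saves,
# trading higher steady-state memory for lower staging latency. The same
# writer instance must therefore be reused across async_save calls.
writer = dcp.FileSystemWriter("/tmp/checkpoint", cache_staged_state_dict=True)

checkpoint_future = None
for step in range(100):
    train_step(model, optim)
    if step % 10 == 0:
        if checkpoint_future is not None:
            # Finish the previous save before staging into the cached buffer again.
            checkpoint_future.result()
        state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
        checkpoint_future = async_save(state_dict, storage_writer=writer)

if checkpoint_future is not None:
    checkpoint_future.result()
```

With the default `cache_staged_state_dict=False`, each `async_save` stages into a freshly allocated CPU copy, so a new writer per save remains safe.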