[DCP] Provides default AsyncStager (#124939)

Differential Revision: [D56575987](https://our.internmc.facebook.com/intern/diff/D56575987/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124939
Approved by: https://github.com/fegin
ghstack dependencies: #122965
Lucas Pasqualin, 2024-05-02 08:53:54 -07:00 (committed by PyTorch MergeBot)
parent 3741fb3680
commit 799f1460af
4 changed files with 131 additions and 7 deletions
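In short, `FileSystemWriter` now doubles as a default `AsyncStager`, so `dcp.async_save` can stage a state_dict to CPU and write it out in the background without a user-supplied stager. A minimal single-process sketch of the resulting workflow (illustrative only: the model and temp directory are placeholders, and real async checkpointing typically runs under an initialized gloo-capable process group):

import tempfile

import torch
import torch.distributed.checkpoint as dcp

model = torch.nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

checkpoint_dir = tempfile.mkdtemp()
# cache_staged_state_dict=True trades higher memory use for lower staging
# latency; the writer is then expected to be reused across async_save calls.
writer = dcp.FileSystemWriter(checkpoint_dir, cache_staged_state_dict=True)

future = dcp.async_save(state_dict, storage_writer=writer)
future.result()  # block until the checkpoint has been written to disk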

docs/source/distributed.checkpoint.rst

@@ -36,6 +36,9 @@ The following module is also useful for additional customization of the staging
 .. autoclass:: torch.distributed.checkpoint.staging.AsyncStager
   :members:
 
+.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager
+  :members:
+
 In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading
 
 .. automodule:: torch.distributed.checkpoint.stateful
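The `Stateful` protocol referenced in these docs is a two-method contract. A small illustrative sketch (the `TrainerState` class is hypothetical, not part of this change):

from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful


class TrainerState(Stateful):
    # Hypothetical app-level state saved alongside model/optimizer state.

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        # DCP calls this during save for any Stateful value in the state_dict.
        return {"step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # DCP calls this during load to restore the captured state in place.
        self.step = state_dict["step"]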

test/distributed/checkpoint/e2e/test_e2e_save_and_load.py

@@ -211,10 +211,18 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
-    def test_e2e_async(self):
-        self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)
+    @parametrize("cache_staged_state_dict", [False, True])
+    def test_e2e_async_cached(self, cache_staged_state_dict):
+        self._run_e2e_test(
+            compile=False,
+            model_type=ModelType.FSDP,
+            async_op=True,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
 
-    def _run_e2e_test(self, compile, model_type, async_op=False):
+    def _run_e2e_test(
+        self, compile, model_type, async_op=False, cache_staged_state_dict=False
+    ):
         model, optim = self._create_model(compile, ModelType.NONE)
         _train(model, optim, train_steps=2)
@@ -230,7 +238,10 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
         }
 
         if async_op:
-            f = saver.async_save(sd, checkpoint_id=self.temp_dir)
+            writer = DCP.FileSystemWriter(
+                self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
+            )
+            f = saver.async_save(sd, storage_writer=writer)
             t = time.monotonic()
             while not f.done():
                 time.sleep(1)
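The cache_staged_state_dict=True case above amounts to reusing one writer across saves. Sketched as a training loop (illustrative; train_step is a placeholder, and per-step checkpoint_id values keep each checkpoint in its own fresh directory):

import torch.distributed.checkpoint as dcp

def checkpoint_loop(model, optim, steps, base_dir):
    # One writer reused across saves so its pinned staging buffers are cached.
    writer = dcp.FileSystemWriter(base_dir, cache_staged_state_dict=True)
    prev = None
    for step in range(steps):
        train_step(model, optim)  # placeholder for the real training step
        if prev is not None:
            prev.result()  # finish the previous save before restaging
        sd = {"model": model.state_dict(), "optim": optim.state_dict()}
        prev = dcp.async_save(
            sd, checkpoint_id=f"{base_dir}/step_{step}", storage_writer=writer
        )
    if prev is not None:
        prev.result()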

torch/distributed/checkpoint/filesystem.py

@@ -11,6 +11,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     cast,
     Dict,
@@ -28,6 +29,8 @@ import torch
 from torch import Tensor
 from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint.staging import BlockingAsyncStager
 from torch.futures import Future
 
 from .metadata import Metadata, MetadataIndex
@@ -393,7 +396,7 @@ class FileSystem(FileSystemBase):
         return False
 
 
-class FileSystemWriter(StorageWriter):
+class _FileSystemWriter(StorageWriter):
     """
     Basic implementation of StorageWriter using file IO.
@@ -414,6 +417,8 @@ class FileSystemWriter(StorageWriter):
         sync_files: bool = True,
         thread_count: int = 1,
         per_thread_copy_ahead: int = 10_000_000,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Initialize the writer pointing to `path`.
@@ -631,3 +636,51 @@ class FileSystemReader(StorageReader):
     @classmethod
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
+    """
+    Basic implementation of StorageWriter using file IO.
+
+    This implementation makes the following assumptions and simplifications:
+
+    * The checkpoint path is an empty or non-existing directory.
+    * File creation is atomic.
+
+    The checkpoint consists of one file per write request plus
+    a `.metadata` file with the serialized metadata.
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+        cache_staged_state_dict: bool = False,
+    ) -> None:
+        """
+        Initialize the writer pointing to `path`.
+
+        Args:
+            path: directory where the checkpoint will be written to.
+            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
+            sync_files: force files to be synced to permanent storage. Defaults to True.
+            thread_count: Number of IO threads to use to write. Defaults to 1.
+            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
+            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
+                at the cost of increased memory usage. Additionally, if this parameter is set to True, it is expected
+                that the stager is maintained and reused for multiple dcp.async_save calls. Defaults to False.
+
+        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
+        """
+        super().__init__(
+            path=path,
+            single_file_per_rank=single_file_per_rank,
+            sync_files=sync_files,
+            thread_count=thread_count,
+            per_thread_copy_ahead=per_thread_copy_ahead,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
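A note on the class design: FileSystemWriter.__init__ passes everything to super().__init__, and the new *args/**kwargs on _FileSystemWriter let keyword arguments it does not consume continue down the MRO toward BlockingAsyncStager. A standalone miniature of that cooperative-inheritance pattern (simplified; these toy classes are not the actual implementation, and they assume each base forwards unconsumed kwargs):

from typing import Any


class _WriterBase:
    def __init__(self, path: str, thread_count: int = 1, **kwargs: Any) -> None:
        self.path = path
        self.thread_count = thread_count
        super().__init__(**kwargs)  # forward what we didn't consume


class _StagerBase:
    def __init__(self, cache_staged_state_dict: bool = False) -> None:
        self.cache_staged_state_dict = cache_staged_state_dict


class Writer(_WriterBase, _StagerBase):
    pass


# Writer's MRO is [Writer, _WriterBase, _StagerBase, object]: the stager
# keyword falls through _WriterBase and lands in _StagerBase.__init__.
w = Writer(path="/tmp/ckpt", thread_count=2, cache_staged_state_dict=True)
assert w.thread_count == 2 and w.cache_staged_state_dict is True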

torch/distributed/checkpoint/staging.py

@@ -1,10 +1,16 @@
-from typing import runtime_checkable
+from typing import Optional, runtime_checkable
 from typing_extensions import Protocol
 
+from torch.distributed._state_dict_utils import (
+    _copy_state_dict,
+    _create_cpu_state_dict,
+    _offload_state_dict_to_cpu,
+)
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
-__all__ = ["AsyncStager"]
+__all__ = ["AsyncStager", "BlockingAsyncStager"]
 
 
 @runtime_checkable
@@ -61,3 +67,54 @@ class AsyncStager(Protocol):
         is complete and it is safe to begin modifying the original `state_dict`
         """
         pass
+
+
+class BlockingAsyncStager(AsyncStager):
+    """
+    An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete.
+    This implementation also provides an option to optimize stage latency using pinned memory.
+
+    N.B. synchronize_staging is a no-op in this case.
+    """
+
+    # stage() blocks until the copy is complete, so there is no need to
+    # synchronize again after it executes.
+    _synchronize_after_execute: bool = False
+
+    def __init__(
+        self,
+        cache_staged_state_dict: bool = False,
+        type_check: bool = False,
+    ):
+        """
+        Initializes the BlockingAsyncStager.
+
+        Args:
+            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
+                at the cost of increased memory usage. Additionally, if this parameter is set to True, it is expected
+                that the stager is maintained and reused for multiple dcp.async_save calls. Defaults to False.
+            type_check: Whether to perform a type check during cpu_offload. Defaults to False.
+        """
+        self.cache_staged_state_dict = cache_staged_state_dict
+        self.type_check = type_check
+        self.state_dict_cache: Optional[STATE_DICT_TYPE] = None
+
+    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+        """
+        Returns a copy of `state_dict` on the CPU.
+        """
+        if not self.cache_staged_state_dict:
+            return _offload_state_dict_to_cpu(state_dict, type_check=self.type_check)
+
+        if self.state_dict_cache is None:
+            self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True)
+        return _copy_state_dict(state_dict, self.state_dict_cache)
+
+    def synchronize_staging(self) -> None:
+        """
+        No-op function, since staging is blocking.
+        """
+        pass
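To make the contract concrete, a small sketch of driving the stager by hand; as I read the PR, async_save does the equivalent internally whenever the storage_writer is also an AsyncStager. The sketch assumes a CUDA build and device, since the cached path allocates pinned CPU memory:

import torch

from torch.distributed.checkpoint.staging import BlockingAsyncStager

stager = BlockingAsyncStager(cache_staged_state_dict=True)

state_dict = {"w": torch.randn(4, 4, device="cuda")}
staged = stager.stage(state_dict)  # blocking copy into pinned CPU memory
stager.synchronize_staging()       # no-op: stage() already blocked

# The staged snapshot lives on CPU, so training may mutate the original
# tensors while the snapshot is being written out.
state_dict["w"].add_(1.0)

# A second stage() reuses the cached pinned buffers instead of reallocating.
staged_again = stager.stage(state_dict)
print(staged_again["w"].device)  # cpu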