[DCP] Provides default AsyncStager (#124939)
Differential Revision: [D56575987](https://our.internmc.facebook.com/intern/diff/D56575987/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124939
Approved by: https://github.com/fegin
ghstack dependencies: #122965
Parent: 3741fb3680
Commit: 799f1460af

4 changed files with 131 additions and 7 deletions
docs/source/distributed.checkpoint.rst

```diff
@@ -36,6 +36,9 @@ The following module is also useful for additional customization of the staging
 .. autoclass:: torch.distributed.checkpoint.staging.AsyncStager
   :members:
 
+.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager
+  :members:
+
 In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading
 
 .. automodule:: torch.distributed.checkpoint.stateful
```
test/distributed/checkpoint/e2e/test_e2e_save_and_load.py

```diff
@@ -211,10 +211,18 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
     def test_e2e_async(self):
         self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)
 
+    @parametrize("cache_staged_state_dict", [False, True])
+    def test_e2e_async_cached(self, cache_staged_state_dict):
+        self._run_e2e_test(
+            compile=False,
+            model_type=ModelType.FSDP,
+            async_op=True,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
+
-    def _run_e2e_test(self, compile, model_type, async_op=False):
+    def _run_e2e_test(
+        self, compile, model_type, async_op=False, cache_staged_state_dict=False
+    ):
         model, optim = self._create_model(compile, ModelType.NONE)
         _train(model, optim, train_steps=2)
@@ -230,7 +238,10 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
         }
 
         if async_op:
-            f = saver.async_save(sd, checkpoint_id=self.temp_dir)
+            writer = DCP.FileSystemWriter(
+                self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
+            )
+            f = saver.async_save(sd, storage_writer=writer)
             t = time.monotonic()
             while not f.done():
                 time.sleep(1)
```
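For reference, the pattern this test exercises — reusing a single `FileSystemWriter` across asynchronous saves so its staging buffer can be cached — might look like the sketch below in a training script. This is a minimal sketch, assuming an initialized process group; `model`, `optim`, `train_one_step`, `CHECKPOINT_DIR`, `NUM_STEPS`, and `SAVE_EVERY` are placeholders, not names from the PR.

```python
import torch.distributed.checkpoint as dcp

# One writer, reused across saves: with cache_staged_state_dict=True the staged
# (pinned-memory) copy of the state_dict is kept between calls, lowering staging
# latency at the cost of extra host memory.
writer = dcp.FileSystemWriter(CHECKPOINT_DIR, cache_staged_state_dict=True)

future = None
for step in range(NUM_STEPS):
    train_one_step(model, optim)  # placeholder training helper
    if step % SAVE_EVERY == 0:
        if future is not None:
            future.result()  # make sure the previous async save has finished
        future = dcp.async_save({"model": model.state_dict()}, storage_writer=writer)

if future is not None:
    future.result()  # wait for the final checkpoint before exiting
```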
torch/distributed/checkpoint/filesystem.py

```diff
@@ -11,6 +11,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     cast,
     Dict,
@@ -28,6 +29,8 @@ import torch
 from torch import Tensor
 from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint.staging import BlockingAsyncStager
+
 from torch.futures import Future
 
 from .metadata import Metadata, MetadataIndex
@@ -393,7 +396,7 @@ class FileSystem(FileSystemBase):
         return False
 
 
-class FileSystemWriter(StorageWriter):
+class _FileSystemWriter(StorageWriter):
     """
     Basic implementation of StorageWriter using file IO.
 
@@ -414,6 +417,8 @@ class FileSystemWriter(StorageWriter):
         sync_files: bool = True,
         thread_count: int = 1,
         per_thread_copy_ahead: int = 10_000_000,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Initialize the writer pointing to `path`.
@@ -631,3 +636,51 @@ class FileSystemReader(StorageReader):
     @classmethod
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
+    """
+    Basic implementation of StorageWriter using file IO.
+
+    This implementation makes the following assumptions and simplifications:
+
+    * The checkpoint path is an empty or non-existing directory.
+    * File creation is atomic.
+
+    The checkpoint consists of one file per write request, plus
+    a `.metadata` file with the serialized metadata.
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+        cache_staged_state_dict: bool = False,
+    ) -> None:
+        """
+        Initialize the writer pointing to `path`.
+
+        Args:
+            path: directory where the checkpoint will be written.
+            single_file_per_rank: produce one file per rank instead of one file per tensor/blob. Defaults to True.
+            sync_files: force files to be synced to permanent storage. Defaults to True.
+            thread_count: number of IO threads to use for writing. Defaults to 1.
+            per_thread_copy_ahead: how many bytes to copy from the GPU ahead of saving. Defaults to 10 MB.
+            cache_staged_state_dict: whether to cache the staged state_dict. This option decreases staging
+                latency at the cost of increased memory usage. If set to True, the stager is expected to be
+                maintained and reused for multiple dcp.async_save calls. Defaults to False.
+
+        N.B. If sync_files is disabled, there is no guarantee that the checkpoint will be consistent in the case of a failure.
+        """
+        super().__init__(
+            path=path,
+            single_file_per_rank=single_file_per_rank,
+            sync_files=sync_files,
+            thread_count=thread_count,
+            per_thread_copy_ahead=per_thread_copy_ahead,
+            cache_staged_state_dict=cache_staged_state_dict,
+        )
```
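Because of the mixin, a `FileSystemWriter` instance now satisfies the `AsyncStager` protocol in addition to `StorageWriter`, which is what lets `async_save` stage through the writer itself. A small illustrative sketch of that relationship — the path and tensor below are arbitrary placeholders, and the direct `stage()` call assumes the cooperative `__init__` chain implied by the `*args`/`**kwargs` addition above:

```python
import torch
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.staging import AsyncStager

writer = FileSystemWriter("/tmp/ckpt")  # placeholder path

# AsyncStager is a runtime-checkable Protocol, so this isinstance check passes.
assert isinstance(writer, AsyncStager)

# stage() returns a CPU copy of the state_dict; with the default
# cache_staged_state_dict=False nothing is cached between calls.
staged = writer.stage({"w": torch.ones(4)})
writer.synchronize_staging()  # no-op for the blocking stager
```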
torch/distributed/checkpoint/staging.py

```diff
@@ -1,10 +1,16 @@
-from typing import runtime_checkable
+from typing import Optional, runtime_checkable
 
 from typing_extensions import Protocol
 
+from torch.distributed._state_dict_utils import (
+    _copy_state_dict,
+    _create_cpu_state_dict,
+    _offload_state_dict_to_cpu,
+)
+
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
-__all__ = ["AsyncStager"]
+__all__ = ["AsyncStager", "BlockingAsyncStager"]
 
 
 @runtime_checkable
@@ -61,3 +67,54 @@ class AsyncStager(Protocol):
         is complete and it is safe to begin modifying the original `state_dict`
         """
         pass
+
+
+class BlockingAsyncStager(AsyncStager):
+    """
+    An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete.
+    This implementation also provides an option to optimize stage latency using pinned memory.
+
+    N.B. synchronize_staging is a no-op in this case.
+    """
+
+    # staging blocks until the copy completes, so no synchronization is needed after execute
+    _synchronize_after_execute: bool = False
+
+    def __init__(
+        self,
+        cache_staged_state_dict: bool = False,
+        type_check: bool = False,
+    ):
+        """
+        Initializes the BlockingAsyncStager.
+
+        Args:
+            cache_staged_state_dict: whether to cache the staged state_dict. This option decreases staging
+                latency at the cost of increased memory usage. If set to True, the stager is expected to be
+                maintained and reused for multiple dcp.async_save calls. Defaults to False.
+            type_check: whether to perform a type check during cpu_offload. Defaults to False.
+        """
+        self.cache_staged_state_dict = cache_staged_state_dict
+        self.type_check = type_check
+        self.state_dict_cache: Optional[STATE_DICT_TYPE] = None
+
+    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+        """
+        Returns a copy of `state_dict` on the CPU.
+        """
+        if not self.cache_staged_state_dict:
+            return _offload_state_dict_to_cpu(state_dict, type_check=self.type_check)
+
+        if self.state_dict_cache is None:
+            self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True)
+        return _copy_state_dict(state_dict, self.state_dict_cache)
+
+    def synchronize_staging(self) -> None:
+        """
+        No-op, since staging is blocking.
+        """
+        pass
```
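Since `AsyncStager` is a runtime-checkable protocol, custom stagers can be built from scratch or by extending `BlockingAsyncStager`. As a hedged sketch under that assumption — the name `TimedStager` and its `last_stage_seconds` attribute are illustrative, not part of the PR — the subclass below reuses the default staging behavior while recording how long each stage took:

```python
import time

from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.staging import BlockingAsyncStager


class TimedStager(BlockingAsyncStager):
    """Stages exactly like BlockingAsyncStager, but records staging latency."""

    last_stage_seconds: float = 0.0

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        start = time.monotonic()
        staged = super().stage(state_dict)  # blocking CPU copy (optionally cached)
        self.last_stage_seconds = time.monotonic() - start
        return staged
```

Pairing such a stager with a storage writer, as the new `FileSystemWriter` does via multiple inheritance, is the pattern this PR establishes.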