[dcp] Integrate stream extensions into DCP impl (#143359)

Summary: Updates FileSystemReader/Writer, Planner, DefaultLoad/SavePlanner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143359
Approved by: https://github.com/saumishr
ghstack dependencies: #143358
This commit is contained in:
Marc Horowitz 2025-01-15 18:54:43 -08:00 committed by PyTorch MergeBot
parent ba3f1c29ee
commit 9c909bf3bb
5 changed files with 251 additions and 19 deletions

View file

@ -8,13 +8,21 @@ from torch.distributed._tensor import (
Shard,
zeros,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.checkpoint_utils import (
get_test_extension_registry,
Rot13Example,
with_temp_dir,
)
CHECKPOINT_DIR = "checkpoint"
@ -41,6 +49,7 @@ for p1 in TWO_D_PLACEMENTS:
TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))
@instantiate_parametrized_tests
class TestDTensorReshardPlacementChange(DTensorTestBase):
"""
Test DCP reshard for DTensor with placement changes and without world_size or mesh_tensor changes.
@ -49,7 +58,8 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
def test_1d_to_1d_reshard_placement_change(self) -> None:
@parametrize("extensions", [None, [Rot13Example()]])
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
CHECKPOINT_DIR = self.temp_dir
for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
@ -65,7 +75,9 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
storage_writer=dist_cp.FileSystemWriter(
path=CHECKPOINT_DIR, _extensions=extensions
),
planner=dist_cp.DefaultSavePlanner(),
)
@ -76,7 +88,9 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
storage_reader=dist_cp.FileSystemReader(
CHECKPOINT_DIR, _extension_registry=get_test_extension_registry()
),
planner=dist_cp.DefaultLoadPlanner(),
)

View file

@ -24,6 +24,8 @@ from torch.distributed.checkpoint import (
)
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
@ -35,6 +37,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
MyShardedModel1,
)
from torch.testing._internal.distributed.checkpoint_utils import (
get_test_extension_registry,
Rot13Example,
)
if TEST_WITH_DEV_DBG_ASAN:
@ -159,7 +165,8 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_read_write_shard_tensor(self) -> None:
@parametrize("extensions", [None, [Rot13Example()]])
def test_read_write_shard_tensor(self, extensions) -> None:
paths = [tempfile.mkdtemp()]
dist.broadcast_object_list(paths)
@ -180,7 +187,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
model_to_save._register_state_dict_hook(state_dict_hook)
state_dict_to_save = model_to_save.state_dict()
fs_writer = FileSystemWriter(path=path)
fs_writer = FileSystemWriter(path=path, _extensions=extensions)
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
dist.barrier()
@ -198,7 +205,9 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Test load.
fs_reader = FileSystemReader(path=path)
fs_reader = FileSystemReader(
path=path, _extension_registry=get_test_extension_registry()
)
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
@ -494,5 +503,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
if __name__ == "__main__":
run_tests()

View file

@ -2,7 +2,7 @@
import sys
import tempfile
from typing import Dict
from typing import Any, Dict, IO
import torch
import torch.distributed as dist
@ -17,9 +17,12 @@ from torch.distributed._shard.sharding_spec import (
from torch.distributed.checkpoint import (
FileSystemReader,
FileSystemWriter,
load,
load_state_dict,
save,
save_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -34,6 +37,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
MyShardedModel1,
)
from torch.testing._internal.distributed.checkpoint_utils import (
get_test_extension_registry,
Rot13Example,
)
if TEST_WITH_DEV_DBG_ASAN:
@ -76,6 +83,8 @@ def assert_state_dict_equal(
torch.equal(value_1, value_2),
f"Key {key}'s tensor does not match",
)
elif isinstance(value_1, Stateful):
self.assertEqual(value_1, value_2)
return True
@ -100,6 +109,23 @@ class MyShardedModel3(torch.nn.Module):
)
class BlobState:
    """Minimal Stateful-like holder for a binary blob.

    Used by the tests to exercise save/load of non-tensor (BYTE_IO)
    checkpoint entries alongside tensors.
    """

    def __init__(self, value: IO[bytes]) -> None:
        # Fix: __init__ must return None, not Any.
        # NOTE(review): the tests pass raw ``bytes`` here despite the
        # IO[bytes] annotation; the value is only stored and compared,
        # so both work — confirm the intended type.
        self.state = {"blob": value}

    def state_dict(self) -> Dict[str, Any]:
        """Return the internal state for checkpointing."""
        return self.state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Replace the internal state with a loaded one."""
        self.state = state_dict

    def __eq__(self, other: object) -> bool:
        # Equality by type and stored state; used by assert_state_dict_equal.
        return isinstance(other, BlobState) and self.state == other.state

    def __repr__(self) -> str:
        return f"BlobState({self.state['blob']})"
class TestDistributedStateDictSaveLoad(TestCase):
@parametrize("thread_count", _THREAD_COUNTS)
def test_read_write_only_tensor(self, thread_count) -> None:
@ -129,6 +155,42 @@ class TestDistributedStateDictSaveLoad(TestCase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadRot13(TestCase):
    """Round-trip a state dict (tensors plus a byte blob) through DCP with
    the Rot13 stream-transform extension enabled on the writer."""

    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_tensor_and_blob(self, thread_count) -> None:
        with tempfile.TemporaryDirectory() as ckpt_dir:
            saved_sd = MyTestModule().state_dict()
            saved_sd["test_blob"] = BlobState(b"SomeBlobForTesting")

            writer = FileSystemWriter(
                path=ckpt_dir,
                thread_count=thread_count,
                _extensions=[Rot13Example()],
            )
            save(
                state_dict=saved_sd,
                storage_writer=writer,
            )

            loaded_sd = MyTestModule().state_dict()
            loaded_sd["test_blob"] = BlobState(b"")
            # Sanity check: freshly initialized dicts must differ before load.
            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, loaded_sd, saved_sd)

            # Load from file without any resharding. Note there is no extension
            # specification here; it is determined dynamically from the metadata.
            reader = FileSystemReader(
                path=ckpt_dir, _extension_registry=get_test_extension_registry()
            )
            load(
                state_dict=loaded_sd,
                storage_reader=reader,
            )

            assert_state_dict_equal(self, loaded_sd, saved_sd)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@property
def world_size(self) -> int:
@ -461,6 +523,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedReshardOnLoad)

View file

@ -3,12 +3,14 @@
import io
import os
from collections.abc import Sequence
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, TYPE_CHECKING, Union
from fsspec.core import url_to_fs
from torch.distributed.checkpoint._extension import StreamTransformExtension
from torch.distributed.checkpoint.filesystem import (
FileSystemBase,
FileSystemReader,
@ -110,6 +112,7 @@ class FsspecWriter(FileSystemWriter):
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
overwrite: bool = True,
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
) -> None:
"""
Initialize the writer pointing to `path`.
@ -121,6 +124,7 @@ class FsspecWriter(FileSystemWriter):
thread_count: Number of IO threads to use to write. Default to 1.
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
_extensions: Extensions to apply to output streams (EXPERIMENTAL)
N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
"""
@ -131,6 +135,7 @@ class FsspecWriter(FileSystemWriter):
thread_count,
per_thread_copy_ahead,
overwrite=overwrite,
_extensions=_extensions,
)
self.fs = FileSystem()
self.path = self.fs.init_path(path)

View file

@ -10,6 +10,7 @@ import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from io import UnsupportedOperation
@ -28,10 +29,17 @@ from typing import (
Union,
)
# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer
import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._extension import (
ExtensionRegistry,
StreamTransformExtension,
)
from torch.distributed.checkpoint.metadata import (
Metadata,
MetadataIndex,
@ -70,6 +78,10 @@ class _StorageInfo:
relative_path: str
offset: int
length: int
transform_descriptors: Optional[Sequence[str]] = None
def __getstate__(self):
return {k: v for k, v in self.__dict__.items() if v is not None}
@dataclass
@ -211,6 +223,57 @@ class _OverlappingCpuLoader(_TensorLoader):
yield from self._finish()
class _StorageWriterTransforms:
"""
This is experimental, and will likely move elsewhere in the
future. It lives here to minimize changes while we are still
learning and gathering feedback.
"""
def __init__(
self, extensions: Optional[Sequence[StreamTransformExtension]] = None
) -> None:
"""
If the extensions arg is None, this means the implementation
should provide whatever defaults it chooses. An empty
sequence indicates no extensions should be used. At this
time, the default extensions sequence is empty.
"""
self.extensions = () if extensions is None else extensions
def transform_save_stream(
self, write_item: WriteItem, raw_stream: io.IOBase
) -> tuple[IO[bytes], List[str]]:
# In order to avoid leaking fds, transformers' close must
# cascade to wrapped streams, but since this function can
# append to the raw stream, we can't close the actual stream.
# So, we use this to put a wrapper around the raw stream's
# close() to make it a noop, and it gets closed once all files
# are appended.
class NoCloseWriter(io.IOBase):
def __init__(self, raw: io.IOBase):
self.raw = raw
def writeable(self) -> bool:
return True
def write(self, b: Buffer) -> int:
return self.raw.write(b)
def close(self):
self.flush()
self.raw.flush()
# but not close.
transform_to = cast(IO[bytes], NoCloseWriter(raw_stream))
for ex in self.extensions:
transform_to = ex.transform_to(transform_to)
return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)])
def _item_size(item: WriteItem) -> int:
size = 1
assert item.tensor_data is not None
@ -247,6 +310,7 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ
def _write_item(
transforms: _StorageWriterTransforms,
stream: io.IOBase,
data: Union[io.BytesIO, torch.Tensor],
write_item: WriteItem,
@ -254,19 +318,36 @@ def _write_item(
) -> WriteResult:
offset = stream.tell()
(transform_to, transform_descriptors) = transforms.transform_save_stream(
write_item, stream
)
if write_item.type == WriteItemType.BYTE_IO:
assert isinstance(data, io.BytesIO)
stream.write(data.getbuffer())
transform_to.write(data.getbuffer())
else:
assert isinstance(data, torch.Tensor)
assert data.device == torch.device("cpu")
torch.save(data, cast(IO[bytes], stream))
torch.save(data, transform_to)
transform_to.close()
length = stream.tell() - offset
# For consistency with earlier versions, leave this field out of the
# metadata if there are no extensions.
info_transform_descriptors = (
None if len(transform_descriptors) == 0 else transform_descriptors
)
return WriteResult(
index=write_item.index,
size_in_bytes=length,
storage_data=_StorageInfo(storage_key, offset, length),
storage_data=_StorageInfo(
storage_key,
offset,
length,
transform_descriptors=info_transform_descriptors,
),
)
@ -275,6 +356,7 @@ def _write_files_from_queue(
file_queue: queue.Queue,
result_queue: queue.Queue,
planner: SavePlanner,
transforms: _StorageWriterTransforms,
inflight_threshhold: int,
use_fsync: bool,
thread_count: int,
@ -319,13 +401,13 @@ def _write_files_from_queue(
for write_item in bytes_w:
data = planner.resolve_data(write_item)
write_results.append(
_write_item(stream, data, write_item, storage_key)
_write_item(transforms, stream, data, write_item, storage_key)
)
for tensor, write_item in loader.values():
assert tensor.is_cpu
write_results.append(
_write_item(stream, tensor, write_item, storage_key)
_write_item(transforms, stream, tensor, write_item, storage_key)
)
if use_fsync:
@ -333,6 +415,7 @@ def _write_files_from_queue(
os.fsync(stream.fileno())
except (AttributeError, UnsupportedOperation):
os.sync()
stream.close()
result_queue.put(write_results)
except queue.Empty:
pass
@ -428,6 +511,7 @@ class FileSystem(FileSystemBase):
class _FileSystemWriter(StorageWriter):
"""
Basic implementation of StorageWriter using file IO.
@ -449,6 +533,7 @@ class _FileSystemWriter(StorageWriter):
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
overwrite: bool = True,
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
*args: Any,
**kwargs: Any,
) -> None:
@ -462,6 +547,7 @@ class _FileSystemWriter(StorageWriter):
thread_count: Number of IO threads to use to write. Default to 1.
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
_extensions: Extensions to apply to output streams (EXPERIMENTAL)
N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
"""
@ -474,6 +560,7 @@ class _FileSystemWriter(StorageWriter):
self.per_thread_copy_ahead = per_thread_copy_ahead
self.save_id = _generate_uuid()
self.overwrite = overwrite
self.transforms = _StorageWriterTransforms(_extensions)
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
if checkpoint_id:
@ -541,6 +628,7 @@ class _FileSystemWriter(StorageWriter):
file_queue,
result_queue,
planner,
self.transforms,
self.per_thread_copy_ahead,
self.sync_files,
self.thread_count,
@ -554,6 +642,7 @@ class _FileSystemWriter(StorageWriter):
file_queue=file_queue,
result_queue=result_queue,
planner=planner,
transforms=self.transforms,
inflight_threshhold=self.per_thread_copy_ahead,
use_fsync=self.sync_files,
thread_count=self.thread_count,
@ -613,16 +702,47 @@ class _FileSystemWriter(StorageWriter):
return FileSystem.validate_checkpoint_id(checkpoint_id)
class _StorageReaderTransforms:
    """
    Resolves per-item transform descriptors into stream extensions on load.

    This is experimental, and will likely move elsewhere in the
    future. It lives here to minimize changes while we are still
    learning and gathering feedback.
    """

    def __init__(self, extension_registry: Optional[ExtensionRegistry] = None) -> None:
        # Fall back to a default registry when none was supplied.
        if extension_registry is None:
            extension_registry = ExtensionRegistry()
        self.extension_registry = extension_registry

    def transform_load_stream(
        self,
        read_item: ReadItem,
        transform_descriptors: Sequence[str],
        raw_stream: IO[bytes],
    ) -> IO[bytes]:
        """Wrap ``raw_stream`` with the extensions named by the descriptors."""
        matched = self.extension_registry.from_descriptor_list(transform_descriptors)
        wrapped = raw_stream
        for extension in matched:
            # Only stream-transform extensions participate in load wrapping.
            if isinstance(extension, StreamTransformExtension):
                wrapped = extension.transform_from(wrapped)
        return wrapped
class FileSystemReader(StorageReader):
def __init__(self, path: Union[str, os.PathLike]) -> None:
def __init__(
self,
path: Union[str, os.PathLike],
_extension_registry: Optional[ExtensionRegistry] = None, # EXPERIMENTAL
) -> None:
super().__init__()
self.fs = FileSystem()
self.path = self.fs.init_path(path)
self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
self.load_id = _generate_uuid()
self.transforms = _StorageReaderTransforms(_extension_registry)
def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
return _create_file_view(file, sinfo.offset, sinfo.length)
def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]:
return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length))
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
self.storage_data = {}
@ -645,15 +765,31 @@ class FileSystemReader(StorageReader):
for req in reqs:
item_md = self.storage_data[req.storage_index]
file_slice = self._slice_file(stream, item_md)
transform_from = self.transforms.transform_load_stream(
req,
# This field wasn't present in older
# implementations so provide a fallback.
item_md.transform_descriptors or (),
file_slice,
)
if req.type == LoadItemType.BYTE_IO:
read_bytes = io.BytesIO(file_slice.read(item_md.length))
read_bytes = io.BytesIO(transform_from.read(-1))
read_bytes.seek(0)
planner.load_bytes(req, read_bytes)
else:
if transform_from.seekable():
seekable = transform_from
else:
# torch.load requires a seekable input, so read the transform
# stream now and store the output if needed
seekable = io.BytesIO(transform_from.read(-1))
seekable.seek(0)
tensor = cast(
Tensor,
torch.load(
cast(IO[bytes], file_slice),
seekable,
map_location="cpu",
weights_only=True,
),
@ -730,6 +866,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
per_thread_copy_ahead: int = 10_000_000,
cache_staged_state_dict: bool = False,
overwrite: bool = True,
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
) -> None:
"""
Initialize the writer pointing to `path`.
@ -744,6 +881,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation
that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
_extensions: Extensions to apply to output streams (EXPERIMENTAL)
N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
"""
@ -755,6 +893,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
thread_count=thread_count,
per_thread_copy_ahead=per_thread_copy_ahead,
overwrite=overwrite,
_extensions=_extensions,
)
BlockingAsyncStager.__init__(
self,