[dcp] Integrate stream extensions into DCP impl (#143359)
Summary: Updates FileSystemReader/Writer, Planner, DefaultLoad/SavePlanner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143359
Approved by: https://github.com/saumishr
ghstack dependencies: #143358
parent ba3f1c29ee
commit 9c909bf3bb
5 changed files with 251 additions and 19 deletions
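The test changes below exercise the new hooks end to end. As a minimal sketch of the call pattern (an existing `state_dict` and `CHECKPOINT_DIR` are assumed; Rot13Example and get_test_extension_registry are the test-only helpers touched in this diff):

import torch.distributed.checkpoint as dist_cp
from torch.testing._internal.distributed.checkpoint_utils import (
    get_test_extension_registry,
    Rot13Example,
)

# Save with an explicit extension list; each extension's descriptor is recorded
# in the checkpoint metadata alongside the written items.
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(
        path=CHECKPOINT_DIR, _extensions=[Rot13Example()]
    ),
)

# Load without naming the extension; the reader rebuilds the transform chain
# from the stored descriptors via the supplied registry.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(
        CHECKPOINT_DIR, _extension_registry=get_test_extension_registry()
    ),
)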
@@ -8,13 +8,21 @@ from torch.distributed._tensor import (
     Shard,
     zeros,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
 )
-from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+from torch.testing._internal.distributed.checkpoint_utils import (
+    get_test_extension_registry,
+    Rot13Example,
+    with_temp_dir,
+)


 CHECKPOINT_DIR = "checkpoint"
@@ -41,6 +49,7 @@ for p1 in TWO_D_PLACEMENTS:
         TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


+@instantiate_parametrized_tests
 class TestDTensorReshardPlacementChange(DTensorTestBase):
     """
     Test DCP reshard for DTensor with placements changes and without world_size change and mesh_tensor change.
@@ -49,7 +58,8 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
-    def test_1d_to_1d_reshard_placement_change(self) -> None:
+    @parametrize("extensions", [None, [Rot13Example()]])
+    def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
         CHECKPOINT_DIR = self.temp_dir

         for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
@@ -65,7 +75,9 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):

             dist_cp.save(
                 state_dict=state_dict_to_save,
-                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
+                storage_writer=dist_cp.FileSystemWriter(
+                    path=CHECKPOINT_DIR, _extensions=extensions
+                ),
                 planner=dist_cp.DefaultSavePlanner(),
             )

@@ -76,7 +88,9 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):

             dist_cp.load(
                 state_dict=state_dict_to_load,
-                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
+                storage_reader=dist_cp.FileSystemReader(
+                    CHECKPOINT_DIR, _extension_registry=get_test_extension_registry()
+                ),
                 planner=dist_cp.DefaultLoadPlanner(),
             )

@@ -24,6 +24,8 @@ from torch.distributed.checkpoint import (
 )
 from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
     TestCase,
@@ -35,6 +37,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
 from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
     MyShardedModel1,
 )
+from torch.testing._internal.distributed.checkpoint_utils import (
+    get_test_extension_registry,
+    Rot13Example,
+)


 if TEST_WITH_DEV_DBG_ASAN:
@@ -159,7 +165,8 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(2)
     @requires_nccl()
-    def test_read_write_shard_tensor(self) -> None:
+    @parametrize("extensions", [None, [Rot13Example()]])
+    def test_read_write_shard_tensor(self, extensions) -> None:
         paths = [tempfile.mkdtemp()]
         dist.broadcast_object_list(paths)

@@ -180,7 +187,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
         model_to_save._register_state_dict_hook(state_dict_hook)
         state_dict_to_save = model_to_save.state_dict()

-        fs_writer = FileSystemWriter(path=path)
+        fs_writer = FileSystemWriter(path=path, _extensions=extensions)
         save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

         dist.barrier()
@@ -198,7 +205,9 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
         assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

         # Test load.
-        fs_reader = FileSystemReader(path=path)
+        fs_reader = FileSystemReader(
+            path=path, _extension_registry=get_test_extension_registry()
+        )
         load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

         assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
@@ -494,5 +503,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
         )


+instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
+
 if __name__ == "__main__":
     run_tests()
@@ -2,7 +2,7 @@

 import sys
 import tempfile
-from typing import Dict
+from typing import Any, Dict, IO

 import torch
 import torch.distributed as dist
@@ -17,9 +17,12 @@ from torch.distributed._shard.sharding_spec import (
 from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
+    load,
     load_state_dict,
+    save,
     save_state_dict,
 )
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -34,6 +37,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
 from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
     MyShardedModel1,
 )
+from torch.testing._internal.distributed.checkpoint_utils import (
+    get_test_extension_registry,
+    Rot13Example,
+)


 if TEST_WITH_DEV_DBG_ASAN:
@@ -76,6 +83,8 @@ def assert_state_dict_equal(
                 torch.equal(value_1, value_2),
                 f"Key {key}'s tensor does not match",
             )
+        elif isinstance(value_1, Stateful):
+            self.assertEqual(value_1, value_2)

     return True

@@ -100,6 +109,23 @@ class MyShardedModel3(torch.nn.Module):
         )


+class BlobState:
+    def __init__(self, value: IO[bytes]) -> Any:
+        self.state = {"blob": value}
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.state
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self.state = state_dict
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, BlobState) and self.state == other.state
+
+    def __repr__(self) -> str:
+        return f"BlobState({self.state['blob']})"
+
+
 class TestDistributedStateDictSaveLoad(TestCase):
     @parametrize("thread_count", _THREAD_COUNTS)
     def test_read_write_only_tensor(self, thread_count) -> None:
@@ -129,6 +155,42 @@ class TestDistributedStateDictSaveLoad(TestCase):
             assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)


+class TestDistributedStateDictSaveLoadRot13(TestCase):
+    @parametrize("thread_count", _THREAD_COUNTS)
+    def test_read_write_tensor_and_blob(self, thread_count) -> None:
+        with tempfile.TemporaryDirectory() as path:
+            state_dict_to_save = MyTestModule().state_dict()
+            state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
+
+            fs_writer = FileSystemWriter(
+                path=path,
+                thread_count=thread_count,
+                _extensions=[Rot13Example()],
+            )
+            save(
+                state_dict=state_dict_to_save,
+                storage_writer=fs_writer,
+            )
+
+            state_dict_to_load_to = MyTestModule().state_dict()
+            state_dict_to_load_to["test_blob"] = BlobState(b"")
+
+            with self.assertRaises(AssertionError):
+                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+            # Load from file without any resharding. Note there is no extension
+            # specification here; it is determined dynamically from the metadata.
+            fs_reader = FileSystemReader(
+                path=path, _extension_registry=get_test_extension_registry()
+            )
+            load(
+                state_dict=state_dict_to_load_to,
+                storage_reader=fs_reader,
+            )
+
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+
 class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -461,6 +523,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):


 instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
+instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
 instantiate_parametrized_tests(TestDistributedReshardOnLoad)

@@ -3,12 +3,14 @@

 import io
 import os
+from collections.abc import Sequence
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Generator, Optional, TYPE_CHECKING, Union

 from fsspec.core import url_to_fs

+from torch.distributed.checkpoint._extension import StreamTransformExtension
 from torch.distributed.checkpoint.filesystem import (
     FileSystemBase,
     FileSystemReader,
@@ -110,6 +112,7 @@ class FsspecWriter(FileSystemWriter):
         thread_count: int = 1,
         per_thread_copy_ahead: int = 10_000_000,
         overwrite: bool = True,
+        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
     ) -> None:
         """
         Initialize the writer pointing to `path`.
@@ -121,6 +124,7 @@ class FsspecWriter(FileSystemWriter):
             thread_count: Number of IO threads to use to write. Default to 1.
             per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
             overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
+            _extensions: Extensions to apply to output streams (EXPERIMENTAL)

         N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
         """
@@ -131,6 +135,7 @@ class FsspecWriter(FileSystemWriter):
             thread_count,
             per_thread_copy_ahead,
             overwrite=overwrite,
+            _extensions=_extensions,
         )
         self.fs = FileSystem()
         self.path = self.fs.init_path(path)
@@ -10,6 +10,7 @@ import threading
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 from io import UnsupportedOperation
@@ -28,10 +29,17 @@ from typing import (
     Union,
 )

+# introduced as collections.abc.Buffer in Python 3.12
+from typing_extensions import Buffer
+
 import torch
 from torch import Tensor
 from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint._extension import (
+    ExtensionRegistry,
+    StreamTransformExtension,
+)
 from torch.distributed.checkpoint.metadata import (
     Metadata,
     MetadataIndex,
@@ -70,6 +78,10 @@ class _StorageInfo:
     relative_path: str
     offset: int
     length: int
+    transform_descriptors: Optional[Sequence[str]] = None
+
+    def __getstate__(self):
+        return {k: v for k, v in self.__dict__.items() if v is not None}


 @dataclass
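Since the new field defaults to None and __getstate__ drops None-valued entries, checkpoints written without extensions pickle exactly the same _StorageInfo payload as before. A small sketch of that behaviour (file name and descriptor string are illustrative only):

si = _StorageInfo("__0_0.distcp", 0, 128)
assert "transform_descriptors" not in si.__getstate__()

# With a transform applied, the descriptor list is carried in the metadata.
si_ext = _StorageInfo("__0_0.distcp", 0, 128, transform_descriptors=["rot13"])
assert si_ext.__getstate__()["transform_descriptors"] == ["rot13"]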
@@ -211,6 +223,57 @@ class _OverlappingCpuLoader(_TensorLoader):
         yield from self._finish()


+class _StorageWriterTransforms:
+    """
+    This is experimental, and will likely move elsewhere in the
+    future. It lives here to minimize changes while we are still
+    learning and gathering feedback.
+    """
+
+    def __init__(
+        self, extensions: Optional[Sequence[StreamTransformExtension]] = None
+    ) -> None:
+        """
+        If the extensions arg is None, this means the implementation
+        should provide whatever defaults it chooses. An empty
+        sequence indicates no extensions should be used. At this
+        time, the default extensions sequence is empty.
+        """
+        self.extensions = () if extensions is None else extensions
+
+    def transform_save_stream(
+        self, write_item: WriteItem, raw_stream: io.IOBase
+    ) -> tuple[IO[bytes], List[str]]:
+        # In order to avoid leaking fds, transformers' close must
+        # cascade to wrapped streams, but since this function can
+        # append to the raw stream, we can't close the actual stream.
+        # So, we use this to put a wrapper around the raw stream's
+        # close() to make it a noop, and it gets closed once all files
+        # are appended.
+
+        class NoCloseWriter(io.IOBase):
+            def __init__(self, raw: io.IOBase):
+                self.raw = raw
+
+            def writeable(self) -> bool:
+                return True
+
+            def write(self, b: Buffer) -> int:
+                return self.raw.write(b)
+
+            def close(self):
+                self.flush()
+                self.raw.flush()
+                # but not close.
+
+        transform_to = cast(IO[bytes], NoCloseWriter(raw_stream))
+
+        for ex in self.extensions:
+            transform_to = ex.transform_to(transform_to)
+
+        return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)])
+
+
 def _item_size(item: WriteItem) -> int:
     size = 1
     assert item.tensor_data is not None
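The writer relies on extensions exposing only three operations: transform_to wraps a write stream, transform_from wraps a read stream, and get_descriptor returns the string recorded in the checkpoint metadata (the base class itself comes from #143358 and is not shown in this diff). A rough, duck-typed sketch of an object satisfying that assumed contract, together with the wrapping order used above:

import io

class NoopTransform:
    # Hypothetical stand-in for a StreamTransformExtension; the real base class
    # in torch.distributed.checkpoint._extension may impose more structure.
    def get_descriptor(self) -> str:
        return "noop"  # descriptor string is illustrative

    def transform_to(self, stream):
        return stream  # a real extension would return a wrapping stream

    def transform_from(self, stream):
        return stream

extensions = [NoopTransform()]
raw = io.BytesIO()
wrapped = raw
for ex in extensions:
    # each extension wraps the previous result, so the last one in the list is outermost
    wrapped = ex.transform_to(wrapped)
# descriptors are recorded outermost-first, matching the reversed() call above
descriptors = [ex.get_descriptor() for ex in reversed(extensions)]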
@@ -247,6 +310,7 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ


 def _write_item(
+    transforms: _StorageWriterTransforms,
     stream: io.IOBase,
     data: Union[io.BytesIO, torch.Tensor],
     write_item: WriteItem,
@@ -254,19 +318,36 @@ def _write_item(
 ) -> WriteResult:
     offset = stream.tell()

+    (transform_to, transform_descriptors) = transforms.transform_save_stream(
+        write_item, stream
+    )
+
     if write_item.type == WriteItemType.BYTE_IO:
         assert isinstance(data, io.BytesIO)
-        stream.write(data.getbuffer())
+        transform_to.write(data.getbuffer())
     else:
         assert isinstance(data, torch.Tensor)
         assert data.device == torch.device("cpu")
-        torch.save(data, cast(IO[bytes], stream))
+        torch.save(data, transform_to)
+    transform_to.close()

     length = stream.tell() - offset

+    # For consistency with earlier versions, leave this field out of the
+    # metadata if there are no extensions.
+    info_transform_descriptors = (
+        None if len(transform_descriptors) == 0 else transform_descriptors
+    )
+
     return WriteResult(
         index=write_item.index,
         size_in_bytes=length,
-        storage_data=_StorageInfo(storage_key, offset, length),
+        storage_data=_StorageInfo(
+            storage_key,
+            offset,
+            length,
+            transform_descriptors=info_transform_descriptors,
+        ),
     )

@@ -275,6 +356,7 @@ def _write_files_from_queue(
     file_queue: queue.Queue,
     result_queue: queue.Queue,
     planner: SavePlanner,
+    transforms: _StorageWriterTransforms,
     inflight_threshhold: int,
     use_fsync: bool,
     thread_count: int,
@@ -319,13 +401,13 @@ def _write_files_from_queue(
             for write_item in bytes_w:
                 data = planner.resolve_data(write_item)
                 write_results.append(
-                    _write_item(stream, data, write_item, storage_key)
+                    _write_item(transforms, stream, data, write_item, storage_key)
                 )

             for tensor, write_item in loader.values():
                 assert tensor.is_cpu
                 write_results.append(
-                    _write_item(stream, tensor, write_item, storage_key)
+                    _write_item(transforms, stream, tensor, write_item, storage_key)
                 )

             if use_fsync:
@@ -333,6 +415,7 @@ def _write_files_from_queue(
                     os.fsync(stream.fileno())
                 except (AttributeError, UnsupportedOperation):
                     os.sync()
+            stream.close()
             result_queue.put(write_results)
     except queue.Empty:
         pass
@@ -428,6 +511,7 @@ class FileSystem(FileSystemBase):


 class _FileSystemWriter(StorageWriter):
+
     """
     Basic implementation of StorageWriter using file IO.

@@ -449,6 +533,7 @@ class _FileSystemWriter(StorageWriter):
         thread_count: int = 1,
         per_thread_copy_ahead: int = 10_000_000,
         overwrite: bool = True,
+        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -462,6 +547,7 @@ class _FileSystemWriter(StorageWriter):
             thread_count: Number of IO threads to use to write. Default to 1.
             per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
             overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
+            _extensions: Extensions to apply to output streams (EXPERIMENTAL)

         N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
         """
@@ -474,6 +560,7 @@ class _FileSystemWriter(StorageWriter):
         self.per_thread_copy_ahead = per_thread_copy_ahead
         self.save_id = _generate_uuid()
         self.overwrite = overwrite
+        self.transforms = _StorageWriterTransforms(_extensions)

     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
         if checkpoint_id:
@@ -541,6 +628,7 @@ class _FileSystemWriter(StorageWriter):
                         file_queue,
                         result_queue,
                         planner,
+                        self.transforms,
                         self.per_thread_copy_ahead,
                         self.sync_files,
                         self.thread_count,
@@ -554,6 +642,7 @@ class _FileSystemWriter(StorageWriter):
             file_queue=file_queue,
             result_queue=result_queue,
             planner=planner,
+            transforms=self.transforms,
             inflight_threshhold=self.per_thread_copy_ahead,
             use_fsync=self.sync_files,
             thread_count=self.thread_count,
@@ -613,16 +702,47 @@ class _FileSystemWriter(StorageWriter):
         return FileSystem.validate_checkpoint_id(checkpoint_id)


+class _StorageReaderTransforms:
+    """
+    This is experimental, and will likely move elsewhere in the
+    future. It lives here to minimize changes while we are still
+    learning and gathering feedback.
+    """
+
+    def __init__(self, extension_registry: Optional[ExtensionRegistry] = None) -> None:
+        self.extension_registry = (
+            ExtensionRegistry() if extension_registry is None else extension_registry
+        )
+
+    def transform_load_stream(
+        self,
+        read_item: ReadItem,
+        transform_descriptors: Sequence[str],
+        raw_stream: IO[bytes],
+    ) -> IO[bytes]:
+        extensions = self.extension_registry.from_descriptor_list(transform_descriptors)
+        transform_from = raw_stream
+        for ex in extensions:
+            if isinstance(ex, StreamTransformExtension):
+                transform_from = ex.transform_from(transform_from)
+        return transform_from
+
+
 class FileSystemReader(StorageReader):
-    def __init__(self, path: Union[str, os.PathLike]) -> None:
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        _extension_registry: Optional[ExtensionRegistry] = None,  # EXPERIMENTAL
+    ) -> None:
         super().__init__()
         self.fs = FileSystem()
         self.path = self.fs.init_path(path)
         self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
         self.load_id = _generate_uuid()
+        self.transforms = _StorageReaderTransforms(_extension_registry)

-    def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
-        return _create_file_view(file, sinfo.offset, sinfo.length)
+    def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]:
+        return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length))

     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
         self.storage_data = {}
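On the read side no extension list is passed at all; the chain is rebuilt per item from the descriptors stored in _StorageInfo, so the reader only needs a registry that can turn descriptor strings back into extension instances. A rough sketch of that lookup (the checkpoint `path` is assumed; get_test_extension_registry and Rot13Example are the test helpers used elsewhere in this diff):

reader = FileSystemReader(path, _extension_registry=get_test_extension_registry())
# During read_data, each item's recorded descriptors are resolved back into
# extension instances and applied to the sliced file view via transform_from().
exts = reader.transforms.extension_registry.from_descriptor_list(
    [Rot13Example().get_descriptor()]
)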
@@ -645,15 +765,31 @@ class FileSystemReader(StorageReader):
                 for req in reqs:
                     item_md = self.storage_data[req.storage_index]
                     file_slice = self._slice_file(stream, item_md)
+                    transform_from = self.transforms.transform_load_stream(
+                        req,
+                        # This field wasn't present in older
+                        # implementations so provide a fallback.
+                        item_md.transform_descriptors or (),
+                        file_slice,
+                    )
+
                     if req.type == LoadItemType.BYTE_IO:
-                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
+                        read_bytes = io.BytesIO(transform_from.read(-1))
                         read_bytes.seek(0)
                         planner.load_bytes(req, read_bytes)
                     else:
+                        if transform_from.seekable():
+                            seekable = transform_from
+                        else:
+                            # torch.load requires a seekable input, so read the transform
+                            # stream now and store the output if needed
+                            seekable = io.BytesIO(transform_from.read(-1))
+                            seekable.seek(0)
+
                         tensor = cast(
                             Tensor,
                             torch.load(
-                                cast(IO[bytes], file_slice),
+                                seekable,
                                 map_location="cpu",
                                 weights_only=True,
                             ),
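Two details above follow from the fact that transforms can change the byte count: item_md.length is the on-disk, post-transform size, so BYTE_IO items drain the decoded stream with read(-1) instead of reading a fixed length, and tensor items are buffered into a BytesIO whenever the decoded stream cannot seek, because torch.load needs a seekable file object. A tiny sketch of that fallback (here `decoded` stands in for the transform_from stream and is assumed to be non-seekable):

buf = io.BytesIO(decoded.read(-1))  # drain the decoded payload, whatever its size
buf.seek(0)
tensor = torch.load(buf, map_location="cpu", weights_only=True)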
@@ -730,6 +866,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
         per_thread_copy_ahead: int = 10_000_000,
         cache_staged_state_dict: bool = False,
         overwrite: bool = True,
+        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
     ) -> None:
         """
         Initialize the writer pointing to `path`.
@@ -744,6 +881,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
                 at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation
                 that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False.
             overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
+            _extensions: Extensions to apply to output streams (EXPERIMENTAL)

         N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
         """
@@ -755,6 +893,7 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
             thread_count=thread_count,
             per_thread_copy_ahead=per_thread_copy_ahead,
             overwrite=overwrite,
+            _extensions=_extensions,
         )
         BlockingAsyncStager.__init__(
             self,