mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[DCP] Makes fsspec public (#121508)"
This reverts commit d482614fec.
Reverted https://github.com/pytorch/pytorch/pull/121508 on behalf of https://github.com/osalpekar due to this causes torchrec tests to fail internally with this error: ModuleNotFoundError: No module named 'fsspec'. see [D54779117](https://www.internalfb.com/diff/D54779117) ([comment](https://github.com/pytorch/pytorch/pull/121508#issuecomment-1992137831))
This commit is contained in:
parent
b84f94f6a3
commit
0398dc9e8e
9 changed files with 245 additions and 196 deletions
|
|
@ -68,20 +68,10 @@ The following types define the planner interface used during checkpoint:
|
|||
|
||||
We provide a filesystem based storage layer:
|
||||
|
||||
.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemReader
|
||||
.. autoclass:: torch.distributed.checkpoint.FileSystemReader
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemWriter
|
||||
:members:
|
||||
|
||||
Additionally, we provide the following abstractions for working with Fsspec storage.
|
||||
|
||||
.. automodule:: torch.distributed.checkpoint.fsspec
|
||||
|
||||
.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecReader
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecWriter
|
||||
.. autoclass:: torch.distributed.checkpoint.FileSystemWriter
|
||||
:members:
|
||||
|
||||
We provide default implementations of `LoadPlanner` and `SavePlanner` that
|
||||
|
|
|
|||
|
|
@ -171,28 +171,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
|||
def test_e2e_async(self):
|
||||
self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_fsspec(self):
|
||||
self._run_e2e_test(
|
||||
compile=False,
|
||||
model_type=ModelType.FSDP,
|
||||
storage_reader=DCP.FsspecReader(),
|
||||
storage_writer=DCP.FsspecWriter(),
|
||||
)
|
||||
|
||||
def _run_e2e_test(
|
||||
self,
|
||||
compile,
|
||||
model_type,
|
||||
async_op=False,
|
||||
storage_reader=None,
|
||||
storage_writer=None,
|
||||
):
|
||||
storage_reader = storage_reader or DCP.FileSystemReader()
|
||||
storage_writer = storage_writer or DCP.FileSystemWriter()
|
||||
|
||||
def _run_e2e_test(self, compile, model_type, async_op=False):
|
||||
model, optim = self._create_model(compile, ModelType.NONE)
|
||||
_train(model, optim, train_steps=2)
|
||||
|
||||
|
|
@ -207,9 +186,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
|||
}
|
||||
|
||||
if async_op:
|
||||
f = saver.async_save(
|
||||
sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer
|
||||
)
|
||||
f = saver.async_save(sd, checkpoint_id=self.temp_dir)
|
||||
t = time.monotonic()
|
||||
while not f.done():
|
||||
time.sleep(1)
|
||||
|
|
@ -217,7 +194,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
|||
|
||||
f.result()
|
||||
else:
|
||||
DCP.save(sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer)
|
||||
DCP.save(sd, checkpoint_id=self.temp_dir)
|
||||
|
||||
loaded_stateful_obj = TestStatefulObj()
|
||||
dist_model, dist_optim = self._create_model(compile, model_type)
|
||||
|
|
@ -232,7 +209,6 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
|||
"s": loaded_stateful_obj,
|
||||
},
|
||||
checkpoint_id=self.temp_dir,
|
||||
storage_reader=storage_reader,
|
||||
)
|
||||
|
||||
self.assertEqual(original_stateful_obj, loaded_stateful_obj)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from torch.distributed.checkpoint import FsspecReader, FsspecWriter
|
||||
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
|
||||
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
||||
|
|
|
|||
|
|
@ -220,30 +220,29 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
|||
|
||||
self._test_save_load(init_model_optim)
|
||||
|
||||
# TODO: these tests are failing on totally unrelated PR's, need to investigate
|
||||
# @with_comms
|
||||
# @skip_if_lt_x_gpu(2)
|
||||
# def test_fsdp(self) -> None:
|
||||
# self.run_subtests(
|
||||
# {
|
||||
# "use_orig_params": [True, False],
|
||||
# "use_composable": [True, False],
|
||||
# "use_dtensor": [True, False],
|
||||
# "wrapping": [tuple(), (nn.Linear, UnitModule)],
|
||||
# },
|
||||
# self._test_fsdp,
|
||||
# )
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_fsdp(self) -> None:
|
||||
self.run_subtests(
|
||||
{
|
||||
"use_orig_params": [True, False],
|
||||
"use_composable": [True, False],
|
||||
"use_dtensor": [True, False],
|
||||
"wrapping": [tuple(), (nn.Linear, UnitModule)],
|
||||
},
|
||||
self._test_fsdp,
|
||||
)
|
||||
|
||||
# @with_comms
|
||||
# @skip_if_lt_x_gpu(2)
|
||||
# def test_compiled_fsdp(self) -> None:
|
||||
# self._test_fsdp(
|
||||
# use_orig_params=True,
|
||||
# use_composable=False,
|
||||
# use_dtensor=False,
|
||||
# wrapping=tuple(),
|
||||
# compile_model=True,
|
||||
# )
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_compiled_fsdp(self) -> None:
|
||||
self._test_fsdp(
|
||||
use_orig_params=True,
|
||||
use_composable=False,
|
||||
use_dtensor=False,
|
||||
wrapping=tuple(),
|
||||
compile_model=True,
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from .api import CheckpointException
|
||||
from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
|
||||
from .filesystem import FileSystemReader, FileSystemWriter
|
||||
from .fsspec import FsspecReader, FsspecWriter
|
||||
from .metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
|
|
|
|||
100
torch/distributed/checkpoint/_checkpointer.py
Normal file
100
torch/distributed/checkpoint/_checkpointer.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from concurrent.futures import Future
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint.state_dict_loader as loader
|
||||
import torch.distributed.checkpoint.state_dict_saver as saver
|
||||
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
|
||||
from torch.distributed.checkpoint.storage import (
|
||||
LoadPlanner,
|
||||
SavePlanner,
|
||||
StorageReader,
|
||||
StorageWriter,
|
||||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
class _Checkpointer:
|
||||
"""This base class specefies a high level API for saving and loading
|
||||
distributed `state_dict` 's. It provides an abstraction over the low-level APIs
|
||||
provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
|
||||
:py:meth: `torch.distributed.state_dict_saver.save` and
|
||||
:py:meth: `torch.distributed.state_dict_loader.load` with the provided storage
|
||||
readers and writers.
|
||||
|
||||
.. warning::
|
||||
This feature is experimental and subject to removal/change.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_writer: StorageWriter,
|
||||
storage_reader: StorageReader,
|
||||
*,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
coordinator_rank: int = 0,
|
||||
no_dist: bool = False,
|
||||
load_planner: Optional[LoadPlanner] = None,
|
||||
save_planner: Optional[SavePlanner] = None,
|
||||
):
|
||||
"""Initializes the Checkpointer instance.
|
||||
|
||||
Args:
|
||||
storage_writer: Instance of StorageWrite use to perform writes.
|
||||
storage_reader: StorageReader used to load data from.
|
||||
process_group: ProcessGroup to be used for cross-rank synchronization.
|
||||
coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
|
||||
no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
|
||||
loader_planner: Instance of LoadPlanner to use when loading.
|
||||
save_planner: Instance of SavePlanner to use when saving.
|
||||
"""
|
||||
self.storage_writer = storage_writer
|
||||
self.storage_reader = storage_reader
|
||||
self.process_group = process_group
|
||||
self.coordinator_rank = coordinator_rank
|
||||
self.no_dist = no_dist
|
||||
self.load_planner = load_planner
|
||||
self.save_planner = save_planner
|
||||
|
||||
def save(
|
||||
self,
|
||||
state_dict: STATE_DICT_TYPE,
|
||||
) -> Metadata:
|
||||
"""Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization."""
|
||||
return saver.save(
|
||||
state_dict,
|
||||
self.storage_writer,
|
||||
process_group=self.process_group,
|
||||
coordinator_rank=self.coordinator_rank,
|
||||
no_dist=self.no_dist,
|
||||
planner=self.save_planner,
|
||||
)
|
||||
|
||||
def async_save(
|
||||
self,
|
||||
state_dict: STATE_DICT_TYPE,
|
||||
) -> Future:
|
||||
"""
|
||||
Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization.
|
||||
|
||||
Returns:
|
||||
Future: A future holding the resultant Metadata object from `save`.
|
||||
"""
|
||||
return saver.async_save(
|
||||
state_dict,
|
||||
storage_writer=self.storage_writer,
|
||||
process_group=self.process_group,
|
||||
planner=self.save_planner,
|
||||
)
|
||||
|
||||
def load(self, state_dict: Dict[str, Any]) -> None:
|
||||
"""Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
|
||||
loader.load(
|
||||
state_dict,
|
||||
storage_reader=self.storage_reader,
|
||||
process_group=self.process_group,
|
||||
planner=self.load_planner,
|
||||
)
|
||||
|
|
@ -1,15 +1,122 @@
|
|||
# Mypy will not try inferring the types of any 3rd party libraries installed.
|
||||
# mypy: ignore-errors
|
||||
|
||||
import logging
|
||||
import io
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from torch.distributed.checkpoint.fsspec import ( # noqa: F401 # noqa: F401
|
||||
FsspecReader,
|
||||
FsspecWriter,
|
||||
import fsspec
|
||||
from fsspec import AbstractFileSystem
|
||||
from fsspec.core import url_to_fs
|
||||
|
||||
from torch.distributed.checkpoint.filesystem import (
|
||||
FileSystemBase,
|
||||
FileSystemReader,
|
||||
FileSystemWriter,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.warning(
|
||||
"FSSpec Filesystem has been made public, please update your "
|
||||
"import to torch.distributed.checkpoint"
|
||||
)
|
||||
__all__ = [
|
||||
"FsspecWriter",
|
||||
"FsspecReader",
|
||||
]
|
||||
|
||||
|
||||
class FileSystem(FileSystemBase):
|
||||
def __init__(self) -> None:
|
||||
self.fs: Optional[AbstractFileSystem] = None
|
||||
|
||||
@contextmanager
|
||||
def create_stream(
|
||||
self, path: Union[str, os.PathLike], mode: str
|
||||
) -> Generator[io.IOBase, None, None]:
|
||||
assert self.fs is not None
|
||||
with self.fs.transaction:
|
||||
with fsspec.open(str(path), mode) as stream:
|
||||
yield stream
|
||||
|
||||
def concat_path(
|
||||
self, path: Union[str, os.PathLike], suffix: str
|
||||
) -> Union[str, os.PathLike]:
|
||||
return os.path.join(path, suffix)
|
||||
|
||||
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
|
||||
self.fs, _ = url_to_fs(path)
|
||||
return path
|
||||
|
||||
def rename(
|
||||
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
|
||||
) -> None:
|
||||
self.fs.rename(path, new_path)
|
||||
|
||||
def mkdir(self, path: [str, os.PathLike]) -> None:
|
||||
self.fs.makedirs(path, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
if isinstance(checkpoint_id, Path):
|
||||
return False
|
||||
|
||||
try:
|
||||
url_to_fs(checkpoint_id)
|
||||
except ValueError as e:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class FsspecWriter(FileSystemWriter):
|
||||
"""
|
||||
Basic implementation of StorageWriter using FFspec.
|
||||
|
||||
This implementation makes the following assumptions and simplifications:
|
||||
|
||||
* The checkpoint path is an empty or non-existing directory.
|
||||
* File creation is atomic
|
||||
|
||||
The checkpoint consist of one file per write request plus
|
||||
a `.metadata` file with the serialized metadata.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Union[str, os.PathLike],
|
||||
single_file_per_rank: bool = True,
|
||||
sync_files: bool = True,
|
||||
thread_count: int = 1,
|
||||
per_thread_copy_ahead: int = 10_000_000,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the writer pointing to `path`.
|
||||
|
||||
Args:
|
||||
path: directory where the checkpoint will be written to.
|
||||
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
|
||||
sync_files : force files to be synced to permanent storage. Default to True.
|
||||
thread_count: Number of IO threads to use to write. Default to 1.
|
||||
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
|
||||
|
||||
N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
|
||||
"""
|
||||
super().__init__(
|
||||
path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
|
||||
)
|
||||
self.fs = FileSystem()
|
||||
self.path = self.fs.init_path(path)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
return FileSystem.validate_checkpoint_id(checkpoint_id)
|
||||
|
||||
|
||||
class FsspecReader(FileSystemReader):
|
||||
def __init__(self, path: Union[str, os.PathLike]) -> None:
|
||||
super().__init__(path)
|
||||
self.fs = FileSystem()
|
||||
self.path = self.fs.init_path(path)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
return FileSystem.check(checkpoint_id)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def _storage_setup(
|
|||
FileSystemWriter,
|
||||
]
|
||||
try:
|
||||
from .fsspec import FsspecReader, FsspecWriter
|
||||
from ._fsspec_filesystem import FsspecReader, FsspecWriter
|
||||
|
||||
targets.append(FsspecReader if reader else FsspecWriter)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -1,122 +0,0 @@
|
|||
# Mypy will not try inferring the types of any 3rd party libraries installed.
|
||||
# mypy: ignore-errors
|
||||
|
||||
import io
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
import fsspec
|
||||
from fsspec import AbstractFileSystem
|
||||
from fsspec.core import url_to_fs
|
||||
|
||||
from torch.distributed.checkpoint.filesystem import (
|
||||
FileSystemBase,
|
||||
FileSystemReader,
|
||||
FileSystemWriter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FsspecWriter",
|
||||
"FsspecReader",
|
||||
]
|
||||
|
||||
|
||||
class FileSystem(FileSystemBase):
|
||||
def __init__(self) -> None:
|
||||
self.fs: Optional[AbstractFileSystem] = None
|
||||
|
||||
@contextmanager
|
||||
def create_stream(
|
||||
self, path: Union[str, os.PathLike], mode: str
|
||||
) -> Generator[io.IOBase, None, None]:
|
||||
assert self.fs is not None
|
||||
with self.fs.transaction:
|
||||
with fsspec.open(str(path), mode) as stream:
|
||||
yield stream
|
||||
|
||||
def concat_path(
|
||||
self, path: Union[str, os.PathLike], suffix: str
|
||||
) -> Union[str, os.PathLike]:
|
||||
return os.path.join(path, suffix)
|
||||
|
||||
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
|
||||
self.fs, _ = url_to_fs(path)
|
||||
return path
|
||||
|
||||
def rename(
|
||||
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
|
||||
) -> None:
|
||||
self.fs.rename(path, new_path)
|
||||
|
||||
def mkdir(self, path: [str, os.PathLike]) -> None:
|
||||
self.fs.makedirs(path, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
if isinstance(checkpoint_id, Path):
|
||||
return False
|
||||
|
||||
try:
|
||||
url_to_fs(checkpoint_id)
|
||||
except ValueError as e:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class FsspecWriter(FileSystemWriter):
|
||||
"""
|
||||
Basic implementation of StorageWriter using FFspec.
|
||||
|
||||
This implementation makes the following assumptions and simplifications:
|
||||
|
||||
* The checkpoint path is an empty or non-existing directory.
|
||||
* File creation is atomic
|
||||
|
||||
The checkpoint consist of one file per write request plus
|
||||
a `.metadata` file with the serialized metadata.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: Union[str, os.PathLike],
|
||||
single_file_per_rank: bool = True,
|
||||
sync_files: bool = True,
|
||||
thread_count: int = 1,
|
||||
per_thread_copy_ahead: int = 10_000_000,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the writer pointing to `path`.
|
||||
|
||||
Args:
|
||||
path: directory where the checkpoint will be written to.
|
||||
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
|
||||
sync_files : force files to be synced to permanent storage. Default to True.
|
||||
thread_count: Number of IO threads to use to write. Default to 1.
|
||||
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
|
||||
|
||||
N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
|
||||
"""
|
||||
super().__init__(
|
||||
path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
|
||||
)
|
||||
self.fs = FileSystem()
|
||||
self.path = self.fs.init_path(path)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
return FileSystem.validate_checkpoint_id(checkpoint_id)
|
||||
|
||||
|
||||
class FsspecReader(FileSystemReader):
|
||||
def __init__(self, path: Union[str, os.PathLike]) -> None:
|
||||
super().__init__(path)
|
||||
self.fs = FileSystem()
|
||||
self.path = self.fs.init_path(path)
|
||||
|
||||
@classmethod
|
||||
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
||||
return FileSystem.check(checkpoint_id)
|
||||
Loading…
Reference in a new issue