Revert "[DCP] Makes fsspec public (#121508)"

This reverts commit d482614fec.

Reverted https://github.com/pytorch/pytorch/pull/121508 on behalf of https://github.com/osalpekar due to this causes torchrec tests to fail internally with this error: ModuleNotFoundError: No module named 'fsspec'. see [D54779117](https://www.internalfb.com/diff/D54779117) ([comment](https://github.com/pytorch/pytorch/pull/121508#issuecomment-1992137831))
This commit is contained in:
PyTorch MergeBot 2024-03-12 17:02:34 +00:00
parent b84f94f6a3
commit 0398dc9e8e
9 changed files with 245 additions and 196 deletions

View file

@ -68,20 +68,10 @@ The following types define the planner interface used during checkpoint:
We provide a filesystem based storage layer:
.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemReader
.. autoclass:: torch.distributed.checkpoint.FileSystemReader
:members:
.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemWriter
:members:
Additionally, we provide the following abstractions for working with Fsspec storage.
.. automodule:: torch.distributed.checkpoint.fsspec
.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecReader
:members:
.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecWriter
.. autoclass:: torch.distributed.checkpoint.FileSystemWriter
:members:
We provide default implementations of `LoadPlanner` and `SavePlanner` that

View file

@ -171,28 +171,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
def test_e2e_async(self):
self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_fsspec(self):
self._run_e2e_test(
compile=False,
model_type=ModelType.FSDP,
storage_reader=DCP.FsspecReader(),
storage_writer=DCP.FsspecWriter(),
)
def _run_e2e_test(
self,
compile,
model_type,
async_op=False,
storage_reader=None,
storage_writer=None,
):
storage_reader = storage_reader or DCP.FileSystemReader()
storage_writer = storage_writer or DCP.FileSystemWriter()
def _run_e2e_test(self, compile, model_type, async_op=False):
model, optim = self._create_model(compile, ModelType.NONE)
_train(model, optim, train_steps=2)
@ -207,9 +186,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
}
if async_op:
f = saver.async_save(
sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer
)
f = saver.async_save(sd, checkpoint_id=self.temp_dir)
t = time.monotonic()
while not f.done():
time.sleep(1)
@ -217,7 +194,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
f.result()
else:
DCP.save(sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer)
DCP.save(sd, checkpoint_id=self.temp_dir)
loaded_stateful_obj = TestStatefulObj()
dist_model, dist_optim = self._create_model(compile, model_type)
@ -232,7 +209,6 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
"s": loaded_stateful_obj,
},
checkpoint_id=self.temp_dir,
storage_reader=storage_reader,
)
self.assertEqual(original_stateful_obj, loaded_stateful_obj)

View file

@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import FsspecReader, FsspecWriter
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

View file

@ -220,30 +220,29 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self._test_save_load(init_model_optim)
# TODO: these tests are failing on totally unrelated PR's, need to investigate
# @with_comms
# @skip_if_lt_x_gpu(2)
# def test_fsdp(self) -> None:
# self.run_subtests(
# {
# "use_orig_params": [True, False],
# "use_composable": [True, False],
# "use_dtensor": [True, False],
# "wrapping": [tuple(), (nn.Linear, UnitModule)],
# },
# self._test_fsdp,
# )
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp(self) -> None:
self.run_subtests(
{
"use_orig_params": [True, False],
"use_composable": [True, False],
"use_dtensor": [True, False],
"wrapping": [tuple(), (nn.Linear, UnitModule)],
},
self._test_fsdp,
)
# @with_comms
# @skip_if_lt_x_gpu(2)
# def test_compiled_fsdp(self) -> None:
# self._test_fsdp(
# use_orig_params=True,
# use_composable=False,
# use_dtensor=False,
# wrapping=tuple(),
# compile_model=True,
# )
@with_comms
@skip_if_lt_x_gpu(2)
def test_compiled_fsdp(self) -> None:
self._test_fsdp(
use_orig_params=True,
use_composable=False,
use_dtensor=False,
wrapping=tuple(),
compile_model=True,
)
@with_comms
@skip_if_lt_x_gpu(2)

View file

@ -1,7 +1,6 @@
from .api import CheckpointException
from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
from .filesystem import FileSystemReader, FileSystemWriter
from .fsspec import FsspecReader, FsspecWriter
from .metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,

View file

@ -0,0 +1,100 @@
from concurrent.futures import Future
from typing import Any, Dict, List, Optional
import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
LoadPlanner,
SavePlanner,
StorageReader,
StorageWriter,
)
__all__: List[str] = []
class _Checkpointer:
"""This base class specefies a high level API for saving and loading
distributed `state_dict` 's. It provides an abstraction over the low-level APIs
provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
:py:meth: `torch.distributed.state_dict_saver.save` and
:py:meth: `torch.distributed.state_dict_loader.load` with the provided storage
readers and writers.
.. warning::
This feature is experimental and subject to removal/change.
"""
def __init__(
self,
storage_writer: StorageWriter,
storage_reader: StorageReader,
*,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
load_planner: Optional[LoadPlanner] = None,
save_planner: Optional[SavePlanner] = None,
):
"""Initializes the Checkpointer instance.
Args:
storage_writer: Instance of StorageWrite use to perform writes.
storage_reader: StorageReader used to load data from.
process_group: ProcessGroup to be used for cross-rank synchronization.
coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
loader_planner: Instance of LoadPlanner to use when loading.
save_planner: Instance of SavePlanner to use when saving.
"""
self.storage_writer = storage_writer
self.storage_reader = storage_reader
self.process_group = process_group
self.coordinator_rank = coordinator_rank
self.no_dist = no_dist
self.load_planner = load_planner
self.save_planner = save_planner
def save(
self,
state_dict: STATE_DICT_TYPE,
) -> Metadata:
"""Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization."""
return saver.save(
state_dict,
self.storage_writer,
process_group=self.process_group,
coordinator_rank=self.coordinator_rank,
no_dist=self.no_dist,
planner=self.save_planner,
)
def async_save(
self,
state_dict: STATE_DICT_TYPE,
) -> Future:
"""
Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization.
Returns:
Future: A future holding the resultant Metadata object from `save`.
"""
return saver.async_save(
state_dict,
storage_writer=self.storage_writer,
process_group=self.process_group,
planner=self.save_planner,
)
def load(self, state_dict: Dict[str, Any]) -> None:
"""Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
loader.load(
state_dict,
storage_reader=self.storage_reader,
process_group=self.process_group,
planner=self.load_planner,
)

View file

@ -1,15 +1,122 @@
# Mypy will not try inferring the types of any 3rd party libraries installed.
# mypy: ignore-errors
import logging
import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union
from torch.distributed.checkpoint.fsspec import ( # noqa: F401 # noqa: F401
FsspecReader,
FsspecWriter,
import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
from torch.distributed.checkpoint.filesystem import (
FileSystemBase,
FileSystemReader,
FileSystemWriter,
)
log = logging.getLogger(__name__)
log.warning(
"FSSpec Filesystem has been made public, please update your "
"import to torch.distributed.checkpoint"
)
__all__ = [
"FsspecWriter",
"FsspecReader",
]
class FileSystem(FileSystemBase):
def __init__(self) -> None:
self.fs: Optional[AbstractFileSystem] = None
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
assert self.fs is not None
with self.fs.transaction:
with fsspec.open(str(path), mode) as stream:
yield stream
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
return os.path.join(path, suffix)
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
self.fs, _ = url_to_fs(path)
return path
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
self.fs.rename(path, new_path)
def mkdir(self, path: [str, os.PathLike]) -> None:
self.fs.makedirs(path, exist_ok=True)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
if isinstance(checkpoint_id, Path):
return False
try:
url_to_fs(checkpoint_id)
except ValueError as e:
return False
return True
class FsspecWriter(FileSystemWriter):
    """
    Basic implementation of StorageWriter using fsspec.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files: force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
        )
        # Swap the base class's local filesystem for the fsspec-backed one,
        # re-resolving `path` so URLs like s3:// work.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to the fsspec ``FileSystem`` URL check."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FsspecReader(FileSystemReader):
    """StorageReader counterpart of ``FsspecWriter``, reading via fsspec."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__(path)
        # Re-resolve `path` through the fsspec-backed filesystem so URLs work.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to the fsspec ``FileSystem`` URL check."""
        # BUG FIX: the original called FileSystem.check(), which does not
        # exist and raised AttributeError; delegate to
        # validate_checkpoint_id, mirroring FsspecWriter.
        return FileSystem.validate_checkpoint_id(checkpoint_id)

View file

@ -32,7 +32,7 @@ def _storage_setup(
FileSystemWriter,
]
try:
from .fsspec import FsspecReader, FsspecWriter
from ._fsspec_filesystem import FsspecReader, FsspecWriter
targets.append(FsspecReader if reader else FsspecWriter)
except Exception:

View file

@ -1,122 +0,0 @@
# Mypy will not try inferring the types of any 3rd party libraries installed.
# mypy: ignore-errors
import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union
import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
from torch.distributed.checkpoint.filesystem import (
FileSystemBase,
FileSystemReader,
FileSystemWriter,
)
__all__ = [
"FsspecWriter",
"FsspecReader",
]
class FileSystem(FileSystemBase):
def __init__(self) -> None:
self.fs: Optional[AbstractFileSystem] = None
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
assert self.fs is not None
with self.fs.transaction:
with fsspec.open(str(path), mode) as stream:
yield stream
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
return os.path.join(path, suffix)
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
self.fs, _ = url_to_fs(path)
return path
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
self.fs.rename(path, new_path)
def mkdir(self, path: [str, os.PathLike]) -> None:
self.fs.makedirs(path, exist_ok=True)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
if isinstance(checkpoint_id, Path):
return False
try:
url_to_fs(checkpoint_id)
except ValueError as e:
return False
return True
class FsspecWriter(FileSystemWriter):
    """
    Basic implementation of StorageWriter using fsspec.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files: force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
        )
        # Swap the base class's local filesystem for the fsspec-backed one,
        # re-resolving `path` so URLs like s3:// work.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to the fsspec ``FileSystem`` URL check."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FsspecReader(FileSystemReader):
    """StorageReader counterpart of ``FsspecWriter``, reading via fsspec."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__(path)
        # Re-resolve `path` through the fsspec-backed filesystem so URLs work.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to the fsspec ``FileSystem`` URL check."""
        # BUG FIX: the original called FileSystem.check(), which does not
        # exist and raised AttributeError; delegate to
        # validate_checkpoint_id, mirroring FsspecWriter.
        return FileSystem.validate_checkpoint_id(checkpoint_id)