diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst index bf06a170681..f758b204ec7 100644 --- a/docs/source/distributed.checkpoint.rst +++ b/docs/source/distributed.checkpoint.rst @@ -68,20 +68,10 @@ The following types define the planner interface used during checkpoint: We provide a filesystem based storage layer: -.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemReader +.. autoclass:: torch.distributed.checkpoint.FileSystemReader :members: -.. autoclass:: torch.distributed.checkpoint.filesystem.FileSystemWriter - :members: - -Additionally, we provide the following abstractions for working with Fsspec storage. - -.. automodule:: torch.distributed.checkpoint.fsspec - -.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecReader - :members: - -.. autoclass:: torch.distributed.checkpoint.fsspec.FsspecWriter +.. autoclass:: torch.distributed.checkpoint.FileSystemWriter :members: We provide default implementations of `LoadPlanner` and `SavePlanner` that diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py index 6b508d4a671..280f16c3d70 100644 --- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py +++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py @@ -171,28 +171,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin): def test_e2e_async(self): self._run_e2e_test(compile=False, model_type=ModelType.FSDP, async_op=True) - @with_comms - @skip_if_lt_x_gpu(4) - @with_temp_dir - def test_fsspec(self): - self._run_e2e_test( - compile=False, - model_type=ModelType.FSDP, - storage_reader=DCP.FsspecReader(), - storage_writer=DCP.FsspecWriter(), - ) - - def _run_e2e_test( - self, - compile, - model_type, - async_op=False, - storage_reader=None, - storage_writer=None, - ): - storage_reader = storage_reader or DCP.FileSystemReader() - storage_writer = storage_writer or DCP.FileSystemWriter() - + def 
_run_e2e_test(self, compile, model_type, async_op=False): model, optim = self._create_model(compile, ModelType.NONE) _train(model, optim, train_steps=2) @@ -207,9 +186,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin): } if async_op: - f = saver.async_save( - sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer - ) + f = saver.async_save(sd, checkpoint_id=self.temp_dir) t = time.monotonic() while not f.done(): time.sleep(1) @@ -217,7 +194,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin): f.result() else: - DCP.save(sd, checkpoint_id=self.temp_dir, storage_writer=storage_writer) + DCP.save(sd, checkpoint_id=self.temp_dir) loaded_stateful_obj = TestStatefulObj() dist_model, dist_optim = self._create_model(compile, model_type) @@ -232,7 +209,6 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin): "s": loaded_stateful_obj, }, checkpoint_id=self.temp_dir, - storage_reader=storage_reader, ) self.assertEqual(original_stateful_obj, loaded_stateful_obj) diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py index a068bcb3a82..b5d41959dc3 100644 --- a/test/distributed/checkpoint/test_fsspec.py +++ b/test/distributed/checkpoint/test_fsspec.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn -from torch.distributed.checkpoint import FsspecReader, FsspecWriter +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 6d8e88e7b81..c553018d365 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ 
b/test/distributed/checkpoint/test_state_dict.py @@ -220,30 +220,29 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): self._test_save_load(init_model_optim) - # TODO: these tests are failing on totally unrelated PR's, need to investigate - # @with_comms - # @skip_if_lt_x_gpu(2) - # def test_fsdp(self) -> None: - # self.run_subtests( - # { - # "use_orig_params": [True, False], - # "use_composable": [True, False], - # "use_dtensor": [True, False], - # "wrapping": [tuple(), (nn.Linear, UnitModule)], - # }, - # self._test_fsdp, - # ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_fsdp(self) -> None: + self.run_subtests( + { + "use_orig_params": [True, False], + "use_composable": [True, False], + "use_dtensor": [True, False], + "wrapping": [tuple(), (nn.Linear, UnitModule)], + }, + self._test_fsdp, + ) - # @with_comms - # @skip_if_lt_x_gpu(2) - # def test_compiled_fsdp(self) -> None: - # self._test_fsdp( - # use_orig_params=True, - # use_composable=False, - # use_dtensor=False, - # wrapping=tuple(), - # compile_model=True, - # ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_compiled_fsdp(self) -> None: + self._test_fsdp( + use_orig_params=True, + use_composable=False, + use_dtensor=False, + wrapping=tuple(), + compile_model=True, + ) @with_comms @skip_if_lt_x_gpu(2) diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py index ff0c799233c..3262acccac1 100644 --- a/torch/distributed/checkpoint/__init__.py +++ b/torch/distributed/checkpoint/__init__.py @@ -1,7 +1,6 @@ from .api import CheckpointException from .default_planner import DefaultLoadPlanner, DefaultSavePlanner from .filesystem import FileSystemReader, FileSystemWriter -from .fsspec import FsspecReader, FsspecWriter from .metadata import ( BytesStorageMetadata, ChunkStorageMetadata, diff --git a/torch/distributed/checkpoint/_checkpointer.py b/torch/distributed/checkpoint/_checkpointer.py new file mode 100644 index 00000000000..a93fe8197de --- /dev/null +++ 
b/torch/distributed/checkpoint/_checkpointer.py @@ -0,0 +1,100 @@ +from concurrent.futures import Future +from typing import Any, Dict, List, Optional + +import torch.distributed as dist +import torch.distributed.checkpoint.state_dict_loader as loader +import torch.distributed.checkpoint.state_dict_saver as saver +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.storage import ( + LoadPlanner, + SavePlanner, + StorageReader, + StorageWriter, +) + + +__all__: List[str] = [] + + +class _Checkpointer: + """This base class specifies a high level API for saving and loading + distributed `state_dict` 's. It provides an abstraction over the low-level APIs + provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling + :py:meth: `torch.distributed.state_dict_saver.save` and + :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage + readers and writers. + + .. warning:: + This feature is experimental and subject to removal/change. + + """ + + def __init__( + self, + storage_writer: StorageWriter, + storage_reader: StorageReader, + *, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + load_planner: Optional[LoadPlanner] = None, + save_planner: Optional[SavePlanner] = None, + ): + """Initializes the Checkpointer instance. + + Args: + storage_writer: Instance of StorageWriter used to perform writes. + storage_reader: StorageReader used to load data from. + process_group: ProcessGroup to be used for cross-rank synchronization. + coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. + no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) + load_planner: Instance of LoadPlanner to use when loading. + save_planner: Instance of SavePlanner to use when saving.
+ """ + self.storage_writer = storage_writer + self.storage_reader = storage_reader + self.process_group = process_group + self.coordinator_rank = coordinator_rank + self.no_dist = no_dist + self.load_planner = load_planner + self.save_planner = save_planner + + def save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Metadata: + """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" + return saver.save( + state_dict, + self.storage_writer, + process_group=self.process_group, + coordinator_rank=self.coordinator_rank, + no_dist=self.no_dist, + planner=self.save_planner, + ) + + def async_save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Future: + """ + Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + """ + return saver.async_save( + state_dict, + storage_writer=self.storage_writer, + process_group=self.process_group, + planner=self.save_planner, + ) + + def load(self, state_dict: Dict[str, Any]) -> None: + """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" + loader.load( + state_dict, + storage_reader=self.storage_reader, + process_group=self.process_group, + planner=self.load_planner, + ) diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index 3dfd7b61e3d..ae97bdf8d53 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -1,15 +1,122 @@ # Mypy will not try inferring the types of any 3rd party libraries installed. 
# mypy: ignore-errors -import logging +import io +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Optional, Union -from torch.distributed.checkpoint.fsspec import ( # noqa: F401 # noqa: F401 - FsspecReader, - FsspecWriter, +import fsspec +from fsspec import AbstractFileSystem +from fsspec.core import url_to_fs + +from torch.distributed.checkpoint.filesystem import ( + FileSystemBase, + FileSystemReader, + FileSystemWriter, ) -log = logging.getLogger(__name__) -log.warning( - "FSSpec Filesystem has been made public, please update your " - "import to torch.distributed.checkpoint" -) +__all__ = [ + "FsspecWriter", + "FsspecReader", +] + + +class FileSystem(FileSystemBase): + def __init__(self) -> None: + self.fs: Optional[AbstractFileSystem] = None + + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + assert self.fs is not None + with self.fs.transaction: + with fsspec.open(str(path), mode) as stream: + yield stream + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return os.path.join(path, suffix) + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + self.fs, _ = url_to_fs(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + self.fs.rename(path, new_path) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + self.fs.makedirs(path, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return False + + try: + url_to_fs(checkpoint_id) + except ValueError: + return False + + return True + + +class FsspecWriter(FileSystemWriter): + """ + Basic implementation of StorageWriter using Fsspec.
+ + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consists of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files: force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
+ """ + super().__init__( + path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead + ) + self.fs = FileSystem() + self.path = self.fs.init_path(path) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FsspecReader(FileSystemReader): + def __init__(self, path: Union[str, os.PathLike]) -> None: + super().__init__(path) + self.fs = FileSystem() + self.path = self.fs.init_path(path) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.check(checkpoint_id) diff --git a/torch/distributed/checkpoint/_storage_utils.py b/torch/distributed/checkpoint/_storage_utils.py index 2b0ba566d47..0f5205a1f20 100644 --- a/torch/distributed/checkpoint/_storage_utils.py +++ b/torch/distributed/checkpoint/_storage_utils.py @@ -32,7 +32,7 @@ def _storage_setup( FileSystemWriter, ] try: - from .fsspec import FsspecReader, FsspecWriter + from ._fsspec_filesystem import FsspecReader, FsspecWriter targets.append(FsspecReader if reader else FsspecWriter) except Exception: diff --git a/torch/distributed/checkpoint/fsspec.py b/torch/distributed/checkpoint/fsspec.py deleted file mode 100644 index ae97bdf8d53..00000000000 --- a/torch/distributed/checkpoint/fsspec.py +++ /dev/null @@ -1,122 +0,0 @@ -# Mypy will not try inferring the types of any 3rd party libraries installed. 
-# mypy: ignore-errors - -import io -import os -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Optional, Union - -import fsspec -from fsspec import AbstractFileSystem -from fsspec.core import url_to_fs - -from torch.distributed.checkpoint.filesystem import ( - FileSystemBase, - FileSystemReader, - FileSystemWriter, -) - -__all__ = [ - "FsspecWriter", - "FsspecReader", -] - - -class FileSystem(FileSystemBase): - def __init__(self) -> None: - self.fs: Optional[AbstractFileSystem] = None - - @contextmanager - def create_stream( - self, path: Union[str, os.PathLike], mode: str - ) -> Generator[io.IOBase, None, None]: - assert self.fs is not None - with self.fs.transaction: - with fsspec.open(str(path), mode) as stream: - yield stream - - def concat_path( - self, path: Union[str, os.PathLike], suffix: str - ) -> Union[str, os.PathLike]: - return os.path.join(path, suffix) - - def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: - self.fs, _ = url_to_fs(path) - return path - - def rename( - self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] - ) -> None: - self.fs.rename(path, new_path) - - def mkdir(self, path: [str, os.PathLike]) -> None: - self.fs.makedirs(path, exist_ok=True) - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - if isinstance(checkpoint_id, Path): - return False - - try: - url_to_fs(checkpoint_id) - except ValueError as e: - return False - - return True - - -class FsspecWriter(FileSystemWriter): - """ - Basic implementation of StorageWriter using FFspec. - - This implementation makes the following assumptions and simplifications: - - * The checkpoint path is an empty or non-existing directory. - * File creation is atomic - - The checkpoint consist of one file per write request plus - a `.metadata` file with the serialized metadata. 
- - """ - - def __init__( - self, - path: Union[str, os.PathLike], - single_file_per_rank: bool = True, - sync_files: bool = True, - thread_count: int = 1, - per_thread_copy_ahead: int = 10_000_000, - ) -> None: - """ - Initialize the writer pointing to `path`. - - Args: - path: directory where the checkpoint will be written to. - single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. - sync_files : force files to be synced to permanent storage. Default to True. - thread_count: Number of IO threads to use to write. Default to 1. - per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. - - N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. - """ - super().__init__( - path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead - ) - self.fs = FileSystem() - self.path = self.fs.init_path(path) - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - return FileSystem.validate_checkpoint_id(checkpoint_id) - - -class FsspecReader(FileSystemReader): - def __init__(self, path: Union[str, os.PathLike]) -> None: - super().__init__(path) - self.fs = FileSystem() - self.path = self.fs.init_path(path) - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - return FileSystem.check(checkpoint_id)