From 7b56b039afe2b4a4038c09d8b6cb1597823f3d5f Mon Sep 17 00:00:00 2001 From: Marc Horowitz Date: Wed, 15 Jan 2025 18:54:44 -0800 Subject: [PATCH] [dcp] Add ZStandard transformer (#143360) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360 Approved by: https://github.com/saumishr ghstack dependencies: #143358, #143359 --- .ci/docker/requirements-ci.txt | 1 + requirements.txt | 1 + .../checkpoint/test_dtensor_resharding.py | 3 +- .../checkpoint/test_file_system_checkpoint.py | 3 +- .../test_file_system_checkpoint_cpu.py | 35 +++++++++++ torch/distributed/checkpoint/_extension.py | 63 ++++++++++++++++++- 6 files changed, 102 insertions(+), 4 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index ecce41f8ea7..a936b2b043f 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -363,6 +363,7 @@ pwlf==2.2.1 ; python_version >= "3.8" astunparse PyYAML setuptools +zstandard ninja==1.11.1 ; platform_machine == "aarch64" scons==4.5.2 ; platform_machine == "aarch64" diff --git a/requirements.txt b/requirements.txt index 03731d2e0fc..64d5563c2d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ setuptools sympy==1.13.3 types-dataclasses typing-extensions>=4.10.0 +zstandard diff --git a/test/distributed/checkpoint/test_dtensor_resharding.py b/test/distributed/checkpoint/test_dtensor_resharding.py index f4e982c3c46..b99e6592c5c 100644 --- a/test/distributed/checkpoint/test_dtensor_resharding.py +++ b/test/distributed/checkpoint/test_dtensor_resharding.py @@ -8,6 +8,7 @@ from torch.distributed._tensor import ( Shard, zeros, ) +from torch.distributed.checkpoint._extension import ZStandard from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase): @with_comms @skip_if_lt_x_gpu(2) @with_temp_dir - @parametrize("extensions", [None, [Rot13Example()]]) + 
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
     def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
         CHECKPOINT_DIR = self.temp_dir
 
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py
index dbfcef0c2f3..c7c6e88b168 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint.py
@@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
     load_state_dict,
     save_state_dict,
 )
+from torch.distributed.checkpoint._extension import ZStandard
 from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -165,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(2)
     @requires_nccl()
-    @parametrize("extensions", [None, [Rot13Example()]])
+    @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
     def test_read_write_shard_tensor(self, extensions) -> None:
         paths = [tempfile.mkdtemp()]
         dist.broadcast_object_list(paths)
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
index a398a55cdb6..f2e1483ce3d 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
@@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
     save,
     save_state_dict,
 )
+from torch.distributed.checkpoint._extension import ZStandard
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -191,6 +192,47 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
             assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
 
+class TestDistributedStateDictSaveLoadZStandard(TestCase):
+    """Save and reload a state dict with the ZStandard stream extension enabled."""
+
+    @parametrize("thread_count", _THREAD_COUNTS)
+    def test_read_write_only_tensor(self, thread_count) -> None:
+        # ZStandard() raises ValueError when the optional zstandard module is
+        # missing; report that as a skip rather than a test error.
+        if not ZStandard.is_available():
+            self.skipTest("zstandard module is not installed")
+
+        with tempfile.TemporaryDirectory() as path:
+            state_dict_to_save = MyTestModule().state_dict()
+            state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
+
+            fs_writer = FileSystemWriter(
+                path=path,
+                thread_count=thread_count,
+                _extensions=[ZStandard()],
+            )
+            save(
+                state_dict=state_dict_to_save,
+                storage_writer=fs_writer,
+            )
+
+            state_dict_to_load_to = MyTestModule().state_dict()
+            state_dict_to_load_to["test_blob"] = BlobState(b"")
+
+            # Sanity check: the fresh state dict must differ before loading.
+            with self.assertRaises(AssertionError):
+                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+            # Load from file without any resharding
+            fs_reader = FileSystemReader(path=path)
+            load(
+                state_dict=state_dict_to_load_to,
+                storage_reader=fs_reader,
+            )
+
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+
 class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -525,6 +567,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
+instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
 instantiate_parametrized_tests(TestDistributedReshardOnLoad)
 
 if __name__ == "__main__":
diff --git a/torch/distributed/checkpoint/_extension.py b/torch/distributed/checkpoint/_extension.py
index 0cadfbc8c52..f03d805c28c 100644
--- a/torch/distributed/checkpoint/_extension.py
+++ b/torch/distributed/checkpoint/_extension.py
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. 
and affiliates
 
 import abc
+import importlib.util
+import sys
 from collections.abc import Sequence
 from typing import IO, Type
 
@@ -9,7 +11,34 @@ from typing import IO, Type
 # change. Feedback and bug fixes are always welcome.
 
 
-__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
+# Name of the optional third-party module that backs the ZStandard extension.
+zstandard_module_name = "zstandard"
+
+# Bind ``zstandard`` to the module when it can be loaded, else to None.
+if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
+    # Already imported elsewhere in this process; reuse it.
+    pass
+elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
+    zstandard = importlib.util.module_from_spec(zstandard_spec)
+    # Register before executing, as the regular import machinery does.
+    sys.modules[zstandard_module_name] = zstandard
+    try:
+        zstandard_spec.loader.exec_module(zstandard)  # type: ignore[union-attr]
+    except BaseException:
+        # A failed import must not leave a half-initialized module behind;
+        # otherwise a later lookup in sys.modules would treat it as usable.
+        sys.modules.pop(zstandard_module_name, None)
+        raise
+else:
+    zstandard = None
+
+
+__all__ = [
+    "Extension",
+    "StreamTransformExtension",
+    "ZStandard",
+    "ExtensionRegistry",
+]
 
 
 class Extension(abc.ABC):
@@ -72,10 +101,70 @@ class StreamTransformExtension(Extension):
         """
 
 
+class ZStandard(StreamTransformExtension):
+    """
+    Stream transform that compresses and decompresses bytes with Zstandard.
+
+    Relies on the optional ``zstandard`` module; check :meth:`is_available`
+    before constructing an instance.
+    """
+
+    @staticmethod
+    def is_available() -> bool:
+        """Return True if the ``zstandard`` module was loaded at import time."""
+        return zstandard is not None
+
+    @staticmethod
+    def from_descriptor(version: str) -> "ZStandard":
+        """
+        Create the extension for a stored descriptor version.
+
+        Only major version "1" (the text before the first ".") is accepted.
+        Raises ValueError for any other version, or when the ``zstandard``
+        module is unavailable.
+        """
+        if version.partition(".")[0] != "1":
+            raise ValueError(f"Unknown extension {version=}")
+        if not ZStandard.is_available():
+            raise ValueError(
+                f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
+            )
+        return ZStandard()
+
+    @staticmethod
+    def registry_name() -> str:
+        """Return the key under which this extension is registered."""
+        return "stream.zstd"
+
+    def __init__(self) -> None:
+        """Raise ValueError if the ``zstandard`` module is unavailable."""
+        super().__init__()
+        if not ZStandard.is_available():
+            raise ValueError(
+                f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
+            )
+
+    def get_descriptor(self) -> str:
+        """Return the descriptor ``"stream.zstd/1"`` identifying this format."""
+        return f"{self.registry_name()}/1"
+
+    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
+        """Wrap ``output`` in a Zstandard-compressing writer stream."""
+        compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
+        return compressor.stream_writer(output)
+
+    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
+        """Wrap ``input`` in a Zstandard-decompressing reader stream."""
+        decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
+        return decompressor.stream_reader(input)
+
+
 class ExtensionRegistry:
     def __init__(self) -> None:
         # Populate default registry contents
-        self.extensions: dict[str, Type[Extension]] = {}
+        self.extensions: dict[str, Type[Extension]] = {
+            cls.registry_name(): cls for cls in (ZStandard,)
+        }
 
     def register(self, cls: Type[Extension]) -> None:
         self.extensions[cls.registry_name()] = cls