[dcp] Add ZStandard transformer (#143360)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360
Approved by: https://github.com/saumishr
ghstack dependencies: #143358, #143359
This commit is contained in:
Marc Horowitz 2025-01-15 18:54:44 -08:00 committed by PyTorch MergeBot
parent 9c909bf3bb
commit 7b56b039af
6 changed files with 102 additions and 4 deletions

View file

@ -363,6 +363,7 @@ pwlf==2.2.1 ; python_version >= "3.8"
astunparse astunparse
PyYAML PyYAML
setuptools setuptools
zstandard
ninja==1.11.1 ; platform_machine == "aarch64" ninja==1.11.1 ; platform_machine == "aarch64"
scons==4.5.2 ; platform_machine == "aarch64" scons==4.5.2 ; platform_machine == "aarch64"

View file

@ -19,3 +19,4 @@ setuptools
sympy==1.13.3 sympy==1.13.3
types-dataclasses types-dataclasses
typing-extensions>=4.10.0 typing-extensions>=4.10.0
zstandard

View file

@ -8,6 +8,7 @@ from torch.distributed._tensor import (
Shard, Shard,
zeros, zeros,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
parametrize, parametrize,
@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
@with_comms @with_comms
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
@with_temp_dir @with_temp_dir
@parametrize("extensions", [None, [Rot13Example()]]) @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None: def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
CHECKPOINT_DIR = self.temp_dir CHECKPOINT_DIR = self.temp_dir

View file

@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
load_state_dict, load_state_dict,
save_state_dict, save_state_dict,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
@ -165,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False) @with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
@requires_nccl() @requires_nccl()
@parametrize("extensions", [None, [Rot13Example()]]) @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
def test_read_write_shard_tensor(self, extensions) -> None: def test_read_write_shard_tensor(self, extensions) -> None:
paths = [tempfile.mkdtemp()] paths = [tempfile.mkdtemp()]
dist.broadcast_object_list(paths) dist.broadcast_object_list(paths)

View file

@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
save, save,
save_state_dict, save_state_dict,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.stateful import Stateful
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
@ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadZStandard(TestCase):
    """Round-trip a state dict through a checkpoint written with ZStandard."""

    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_only_tensor(self, thread_count) -> None:
        # ZStandard() raises ValueError when the optional `zstandard` module
        # cannot be found, so report a skip rather than an error in that case.
        if not ZStandard.is_available():
            self.skipTest("zstandard module is not available")

        with tempfile.TemporaryDirectory() as path:
            state_dict_to_save = MyTestModule().state_dict()
            state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")

            fs_writer = FileSystemWriter(
                path=path,
                thread_count=thread_count,
                _extensions=[ZStandard()],
            )
            save(
                state_dict=state_dict_to_save,
                storage_writer=fs_writer,
            )

            state_dict_to_load_to = MyTestModule().state_dict()
            state_dict_to_load_to["test_blob"] = BlobState(b"")

            # Sanity check: before loading, the two state dicts must differ,
            # otherwise the final equality assertion would prove nothing.
            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

            # Load from file without any resharding
            fs_reader = FileSystemReader(path=path)
            load(
                state_dict=state_dict_to_load_to,
                storage_reader=fs_reader,
            )

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@property @property
def world_size(self) -> int: def world_size(self) -> int:
@ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad) instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13) instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor) instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
instantiate_parametrized_tests(TestDistributedReshardOnLoad) instantiate_parametrized_tests(TestDistributedReshardOnLoad)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1,6 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
import abc import abc
import importlib.util
import sys
from collections.abc import Sequence from collections.abc import Sequence
from typing import IO, Type from typing import IO, Type
@ -9,7 +11,24 @@ from typing import IO, Type
# change. Feedback and bug fixes are always welcome. # change. Feedback and bug fixes are always welcome.
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"] zstandard_module_name = "zstandard"
# Resolve the optional `zstandard` dependency without making it a hard import.
# A module already present in sys.modules is reused as-is. Otherwise, if a spec
# can be found, it is imported through the regular import machinery, which
# (unlike a manual module_from_spec/exec_module sequence) does not leave a
# partially initialised module behind in sys.modules when executing it raises.
# `zstandard` stays None when the module cannot be found.
zstandard = sys.modules.get(zstandard_module_name)
if zstandard is None and importlib.util.find_spec(zstandard_module_name) is not None:
    zstandard = importlib.import_module(zstandard_module_name)
__all__ = [
"Extension",
"StreamTransformExtension",
"ZStandard",
"ExtensionRegistry",
]
class Extension(abc.ABC): class Extension(abc.ABC):
@ -72,10 +91,50 @@ class StreamTransformExtension(Extension):
""" """
class ZStandard(StreamTransformExtension):
    """Stream transform extension backed by the optional ``zstandard`` module.

    Output streams are wrapped in a ``ZstdCompressor`` stream writer and input
    streams in a ``ZstdDecompressor`` stream reader.  Instances can only be
    created when the module-level ``zstandard`` reference is not ``None``.
    """

    def __init__(self) -> None:
        """Create the extension.

        Raises:
            ValueError: if the ``zstandard`` module is not available.
        """
        super().__init__()
        if ZStandard.is_available():
            return
        raise ValueError(
            f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
        )

    @staticmethod
    def is_available() -> bool:
        """Return True when the ``zstandard`` module was resolved."""
        return zstandard is not None

    @staticmethod
    def registry_name() -> str:
        """Return the name this extension is registered under."""
        return "stream.zstd"

    @staticmethod
    def from_descriptor(version: str) -> "ZStandard":
        """Build an instance for a descriptor version string.

        Only versions whose component before the first ``"."`` is ``"1"`` are
        accepted.

        Raises:
            ValueError: if the version is not recognised, or if the
                ``zstandard`` module is not available.
        """
        major = version.split(".", 1)[0]
        if major != "1":
            raise ValueError(f"Unknown extension {version=}")
        if ZStandard.is_available():
            return ZStandard()
        raise ValueError(
            f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
        )

    def get_descriptor(self) -> str:
        """Return the registry name followed by ``/1``."""
        name = self.registry_name()
        return f"{name}/1"

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        """Wrap ``output`` in a stream writer from a new ``ZstdCompressor``."""
        return zstandard.ZstdCompressor().stream_writer(output)  # type: ignore[union-attr]

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        """Wrap ``input`` in a stream reader from a new ``ZstdDecompressor``."""
        return zstandard.ZstdDecompressor().stream_reader(input)  # type: ignore[union-attr]
class ExtensionRegistry: class ExtensionRegistry:
def __init__(self) -> None: def __init__(self) -> None:
# Populate default registry contents # Populate default registry contents
self.extensions: dict[str, Type[Extension]] = {} self.extensions: dict[str, Type[Extension]] = {
cls.registry_name(): cls for cls in (ZStandard,)
}
def register(self, cls: Type[Extension]) -> None: def register(self, cls: Type[Extension]) -> None:
self.extensions[cls.registry_name()] = cls self.extensions[cls.registry_name()] = cls