mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dcp] Add ZStandard transformer (#143360)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360 Approved by: https://github.com/saumishr ghstack dependencies: #143358, #143359
This commit is contained in:
parent
9c909bf3bb
commit
7b56b039af
6 changed files with 102 additions and 4 deletions
|
|
@ -363,6 +363,7 @@ pwlf==2.2.1 ; python_version >= "3.8"
|
||||||
astunparse
|
astunparse
|
||||||
PyYAML
|
PyYAML
|
||||||
setuptools
|
setuptools
|
||||||
|
zstandard
|
||||||
|
|
||||||
ninja==1.11.1 ; platform_machine == "aarch64"
|
ninja==1.11.1 ; platform_machine == "aarch64"
|
||||||
scons==4.5.2 ; platform_machine == "aarch64"
|
scons==4.5.2 ; platform_machine == "aarch64"
|
||||||
|
|
|
||||||
|
|
@ -19,3 +19,4 @@ setuptools
|
||||||
sympy==1.13.3
|
sympy==1.13.3
|
||||||
types-dataclasses
|
types-dataclasses
|
||||||
typing-extensions>=4.10.0
|
typing-extensions>=4.10.0
|
||||||
|
zstandard
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from torch.distributed._tensor import (
|
||||||
Shard,
|
Shard,
|
||||||
zeros,
|
zeros,
|
||||||
)
|
)
|
||||||
|
from torch.distributed.checkpoint._extension import ZStandard
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
parametrize,
|
parametrize,
|
||||||
|
|
@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
|
||||||
@with_comms
|
@with_comms
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@parametrize("extensions", [None, [Rot13Example()]])
|
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||||
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
||||||
CHECKPOINT_DIR = self.temp_dir
|
CHECKPOINT_DIR = self.temp_dir
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
|
||||||
load_state_dict,
|
load_state_dict,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
)
|
)
|
||||||
|
from torch.distributed.checkpoint._extension import ZStandard
|
||||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
|
|
@ -165,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||||
@with_comms(init_rpc=False)
|
@with_comms(init_rpc=False)
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@parametrize("extensions", [None, [Rot13Example()]])
|
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||||
def test_read_write_shard_tensor(self, extensions) -> None:
|
def test_read_write_shard_tensor(self, extensions) -> None:
|
||||||
paths = [tempfile.mkdtemp()]
|
paths = [tempfile.mkdtemp()]
|
||||||
dist.broadcast_object_list(paths)
|
dist.broadcast_object_list(paths)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
|
||||||
save,
|
save,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
)
|
)
|
||||||
|
from torch.distributed.checkpoint._extension import ZStandard
|
||||||
from torch.distributed.checkpoint.stateful import Stateful
|
from torch.distributed.checkpoint.stateful import Stateful
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
|
|
@ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
|
||||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistributedStateDictSaveLoadZStandard(TestCase):
|
||||||
|
@parametrize("thread_count", _THREAD_COUNTS)
|
||||||
|
def test_read_write_only_tensor(self, thread_count) -> None:
|
||||||
|
with tempfile.TemporaryDirectory() as path:
|
||||||
|
state_dict_to_save = MyTestModule().state_dict()
|
||||||
|
state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
|
||||||
|
|
||||||
|
fs_writer = FileSystemWriter(
|
||||||
|
path=path,
|
||||||
|
thread_count=thread_count,
|
||||||
|
_extensions=[ZStandard()],
|
||||||
|
)
|
||||||
|
save(
|
||||||
|
state_dict=state_dict_to_save,
|
||||||
|
storage_writer=fs_writer,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict_to_load_to = MyTestModule().state_dict()
|
||||||
|
state_dict_to_load_to["test_blob"] = BlobState(b"")
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||||
|
|
||||||
|
# Load from file without any resharding
|
||||||
|
fs_reader = FileSystemReader(path=path)
|
||||||
|
load(
|
||||||
|
state_dict=state_dict_to_load_to,
|
||||||
|
storage_reader=fs_reader,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||||
|
|
||||||
|
|
||||||
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||||
@property
|
@property
|
||||||
def world_size(self) -> int:
|
def world_size(self) -> int:
|
||||||
|
|
@ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
||||||
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
|
||||||
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import IO, Type
|
from typing import IO, Type
|
||||||
|
|
||||||
|
|
@ -9,7 +11,24 @@ from typing import IO, Type
|
||||||
# change. Feedback and bug fixes are always welcome.
|
# change. Feedback and bug fixes are always welcome.
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
|
zstandard_module_name = "zstandard"
|
||||||
|
|
||||||
|
if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
|
||||||
|
pass
|
||||||
|
elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
|
||||||
|
zstandard = importlib.util.module_from_spec(zstandard_spec)
|
||||||
|
sys.modules[zstandard_module_name] = zstandard
|
||||||
|
zstandard_spec.loader.exec_module(zstandard) # type: ignore[union-attr]
|
||||||
|
else:
|
||||||
|
zstandard = None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Extension",
|
||||||
|
"StreamTransformExtension",
|
||||||
|
"ZStandard",
|
||||||
|
"ExtensionRegistry",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Extension(abc.ABC):
|
class Extension(abc.ABC):
|
||||||
|
|
@ -72,10 +91,50 @@ class StreamTransformExtension(Extension):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ZStandard(StreamTransformExtension):
|
||||||
|
@staticmethod
|
||||||
|
def is_available() -> bool:
|
||||||
|
return zstandard is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_descriptor(version: str) -> "ZStandard":
|
||||||
|
if version.partition(".")[0] != "1":
|
||||||
|
raise ValueError(f"Unknown extension {version=}")
|
||||||
|
if not ZStandard.is_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
|
||||||
|
)
|
||||||
|
return ZStandard()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def registry_name() -> str:
|
||||||
|
return "stream.zstd"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if not ZStandard.is_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_descriptor(self) -> str:
|
||||||
|
return f"{self.registry_name()}/1"
|
||||||
|
|
||||||
|
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
|
||||||
|
compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
|
||||||
|
return compressor.stream_writer(output)
|
||||||
|
|
||||||
|
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
|
||||||
|
decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
|
||||||
|
return decompressor.stream_reader(input)
|
||||||
|
|
||||||
|
|
||||||
class ExtensionRegistry:
|
class ExtensionRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# Populate default registry contents
|
# Populate default registry contents
|
||||||
self.extensions: dict[str, Type[Extension]] = {}
|
self.extensions: dict[str, Type[Extension]] = {
|
||||||
|
cls.registry_name(): cls for cls in (ZStandard,)
|
||||||
|
}
|
||||||
|
|
||||||
def register(self, cls: Type[Extension]) -> None:
|
def register(self, cls: Type[Extension]) -> None:
|
||||||
self.extensions[cls.registry_name()] = cls
|
self.extensions[cls.registry_name()] = cls
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue