pytorch/test/distributed/checkpoint/test_compatibility.py
Chien-Chin Huang c7338f457c [DCP] Fixes the BC issue where the traversal doesn't support versions before 2.4 (#134158)
The original DCP doesn't flatten all the containers, which can cause issues. https://github.com/pytorch/pytorch/pull/125335 intends to solve the issue by flattening all the dictionaries.
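
For illustration, here is a minimal sketch of what "flattening" means in this context. It is a simplified stand-in, not DCP's actual traversal code: nested dictionaries collapse into a single mapping keyed by dotted FQNs.

```python
# Simplified illustration of dictionary flattening, not DCP's real code.
def flatten(obj, prefix=""):
    out = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            out.update(flatten(value, f"{prefix}{key}."))
    else:
        out[prefix[:-1]] = obj  # strip the trailing "."
    return out

# flatten({"dict": {"a": 1, "b": 2}}) == {"dict.a": 1, "dict.b": 2}
```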

Unfortunately, that change breaks checkpoints saved before 2.4. This
also exposes some issues with DCP:

1. DCP should record version in the metadata.
2. DCP should have a nice way to load old state_dict.
3. DCP should unflatten all containers (map, list), not just map.

This PR addresses only issue 2 to unblock users. Issues 1 and 3 need to be addressed in the future.
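
To make issue 3 concrete, here is an illustrative sketch (again, not DCP's code) of unflattening that rebuilds both maps and lists; consecutive integer keys are converted back into list indices.

```python
# Illustrative sketch: rebuild nested dicts *and* lists from dotted FQNs.
def unflatten(flat):
    root = {}
    for fqn, value in flat.items():
        keys = fqn.split(".")
        node = root
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value

    def listify(node):
        if not isinstance(node, dict):
            return node
        node = {k: listify(v) for k, v in node.items()}
        # Treat keys "0".."n-1" as list indices.
        if node and set(node) == {str(i) for i in range(len(node))}:
            return [node[str(i)] for i in range(len(node))]
        return node

    return listify(root)

# unflatten({"list.0": 10, "list.1": 11}) == {"list": [10, 11]}
```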

@pradeepfn Please let me know if this summary matches our discussion.

Fixes https://github.com/pytorch/pytorch/issues/133923

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134158
Approved by: https://github.com/wz337, https://github.com/pradeepfn
2024-08-28 16:31:44 +00:00


# Owner(s): ["oncall: distributed"]

from unittest.mock import patch

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class TestDCPCompatibility(TestCase):
    def test_metadata(self) -> None:
        # Ensure that all the new fields of all the metadata classes have
        # default values so that we can always deserialize legacy metadata.
        try:
            tensor = torch.zeros(4, 4)
            chunk_meta = ChunkStorageMetadata(
                torch.Size((1, 1)),
                torch.Size((1, 1)),
            )
            tensor_meta = TensorStorageMetadata(
                properties=TensorProperties.create_from_tensor(tensor),
                size=tensor.size(),
                chunks=[chunk_meta],
            )
            b_meta = BytesStorageMetadata()
            _ = Metadata(state_dict_metadata={"a": tensor_meta, "b": b_meta})
            _ = MetadataIndex(fqn="a.b.c")
        except Exception as e:
            raise RuntimeError(
                "The change may break the BC of distributed checkpoint."
            ) from e

    def test_sharded_tensor_dependency(self) -> None:
        # Ensure that we can load existing DCP checkpoints back even if the
        # metadata contains _shard.sharded_tensor.metadata.TensorProperties.
        from torch.distributed._shard.sharded_tensor.metadata import (
            TensorProperties as stp,
        )
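
        # Patching the metadata symbol makes DCP pickle the legacy
        # ShardedTensor TensorProperties into the saved metadata, emulating
        # a checkpoint written by an older release.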
        with patch("torch.distributed.checkpoint.metadata.TensorProperties", stp):
            dcp.save(
                {"a": torch.zeros(4, 4)},
                dcp.FileSystemWriter("/tmp/dcp_testing"),
            )
            dcp.load(
                {"a": torch.zeros(4, 4)},
                dcp.FileSystemReader("/tmp/dcp_testing"),
            )

    @with_temp_dir
    def test_storage_meta(self) -> None:
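        # Check that the saved checkpoint records StorageMeta: the
        # checkpoint_id plus the writer's save_id and the reader's load_id.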
        writer = dcp.FileSystemWriter(self.temp_dir)
        dcp.save({"a": torch.zeros(4, 4)}, storage_writer=writer)

        reader = dcp.FileSystemReader(self.temp_dir)
        storage_meta = reader.read_metadata().storage_meta
        self.assertNotEqual(storage_meta, None)
        self.assertEqual(str(storage_meta.checkpoint_id), self.temp_dir)
        self.assertEqual(storage_meta.save_id, writer.save_id)
        self.assertEqual(storage_meta.load_id, reader.load_id)

    @with_temp_dir
    def test_with_v_2_3(self) -> None:
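        # Save while acting like DCP 2.3 (nested containers are written in
        # the legacy, unflattened layout), then load with the current
        # version to exercise the BC path fixed by this PR.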
        sd = {
            "a": torch.zeros(4, 4),
            "dict": {
                "dict_a": {"dict_a_1": 1, "dict_a_2": 2},
                "dict_b": {"dict_b_1": 1, "dict_b_2": 2},
            },
            "list": [0, 1, 2, 3, 4, 5],
        }
        load_sd = {
            "a": torch.ones(4, 4),
            "dict": {
                "dict_a": {"dict_a_1": 2, "dict_a_2": 4},
                "dict_b": {"dict_b_1": 2, "dict_b_2": 4},
            },
            "list": [10, 11, 12, 13, 14, 15],
        }

        dcp._version._act_like_version = "2_3"
        dcp.save(sd, checkpoint_id=self.temp_dir)
        dcp._version._act_like_version = None
        dcp.load(load_sd, checkpoint_id=self.temp_dir)
        self.assertEqual(sd, load_sd)


if __name__ == "__main__":
    run_tests()