pytorch/test/distributed/checkpoint/test_utils.py
Marc Horowitz 95d333f52e [distributed] Fix _ReaderView.read() and readinto() to stop reading at the end of the slice (#143357)
_ReaderView doesn't work correctly when a read request extends past the end of the view.

read(-1) would call read(-1) on the base_stream, which would consume the entire underlying stream, even if the view ended before that.
read(n) would read n bytes, even if the view ended before that.

The new implementation clamps the size read to the size of the view.

readinto(b) would read len(b) bytes, even if the view ended before that.

Since the interface depends on the size of b, we use a (potentially) shortened memoryview into b to avoid a copy. If the underlying view doesn't contain enough data to fill that buffer, the shortfall appears to the caller as end of stream, which is the desired behavior.

This fix should not be user facing, since the bug is in an internal helper, and is only visible with new code down the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143357
Approved by: https://github.com/saumishr
2025-01-11 00:22:10 +00:00

189 lines
6.7 KiB
Python

# Owner(s): ["oncall: distributed"]
import io
import sys
import torch
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
ShardedTensorMetadata,
ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.c10d_logger import _c10d_logger
from torch.distributed.checkpoint.logger import _dcp_logger
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import _create_file_view, find_state_dict_object
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.distributed.distributed_utils import with_fake_comms
# Skip the whole module under dev/debug ASAN builds: torch combined with
# multiprocessing spawn has known issues there, so exit cleanly before any
# tests are collected.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
def create_sharded_tensor(rank, world_size, shards_per_rank):
    """Build a 1-D ShardedTensor with ``shards_per_rank`` contiguous shards per rank.

    Shard metadata is produced for every rank in the world, each shard 8
    elements wide, but tensor storage is materialized only for the shards
    that ``rank`` owns.
    """
    shards_metadata = [
        ShardMetadata(
            shard_offsets=[idx * 8],
            shard_sizes=[8],
            placement=f"rank:{idx // shards_per_rank}/cpu",
        )
        for idx in range(world_size * shards_per_rank)
    ]
    # Only materialize data for shards placed on this rank.
    local_shards = [
        Shard.from_tensor_and_offsets(
            torch.rand(*md.shard_sizes),
            shard_offsets=md.shard_offsets,
            rank=rank,
        )
        for idx, md in enumerate(shards_metadata)
        if idx // shards_per_rank == rank
    ]
    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([8 * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
    )
class TestMedatadaIndex(TestCase):
    # NOTE: class name spelling ("Medatada") is long-standing; kept so test
    # selection by name keeps working.
    """Tests for MetadataIndex equality/hashing and find_state_dict_object."""

    def test_init_convert_offset(self):
        # A plain list offset and a torch.Size offset compare equal.
        from_list = MetadataIndex("foo", [1, 2])
        from_size = MetadataIndex("foo", torch.Size([1, 2]))
        self.assertEqual(from_list, from_size)

    def test_index_hint_ignored_on_equals(self):
        no_hint = MetadataIndex("foo")
        with_hint = MetadataIndex("foo", index=99)
        self.assertEqual(no_hint, with_hint)

    def test_index_hint_ignored_on_hash(self):
        no_hint = MetadataIndex("foo")
        with_hint = MetadataIndex("foo", index=99)
        self.assertEqual(hash(no_hint), hash(with_hint))

    def test_flat_data(self):
        state_dict = {
            "a": torch.rand(10),
            "b": [1, 2, 3],
        }

        # Tensors resolve by FQN; offset and index hints are ignored.
        for idx in (
            MetadataIndex("a"),
            MetadataIndex("a", [0]),
            MetadataIndex("a", index=99),
        ):
            self.assertEqual(
                find_state_dict_object(state_dict, idx), state_dict["a"]
            )

        # Non-tensor objects resolve by FQN too; index hints are ignored.
        for idx in (MetadataIndex("b"), MetadataIndex("b", index=1)):
            self.assertEqual(
                find_state_dict_object(state_dict, idx), state_dict["b"]
            )

        with self.assertRaisesRegex(ValueError, "FQN"):
            find_state_dict_object(state_dict, MetadataIndex("c"))
        # Offsets are only meaningful for ShardedTensor values.
        with self.assertRaisesRegex(ValueError, "ShardedTensor"):
            find_state_dict_object(state_dict, MetadataIndex("b", [1]))

    @with_fake_comms(rank=0, world_size=2)
    def test_sharded_tensor_lookup(self):
        st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3)
        state_dict = {"st": st}
        expected = st.local_shards()[1].tensor

        # Offset [8] maps to the second local shard regardless of whether the
        # index hint is absent, correct (1), wrong (2), or out of range (99).
        self.assertEqual(
            find_state_dict_object(state_dict, MetadataIndex("st", [8])), expected
        )
        for hint in (1, 2, 99):
            found = find_state_dict_object(
                state_dict, MetadataIndex("st", [8], index=hint)
            )
            self.assertEqual(found, expected)

        with self.assertRaisesRegex(ValueError, "no offset was provided"):
            find_state_dict_object(state_dict, MetadataIndex("st"))
        with self.assertRaisesRegex(ValueError, "Could not find shard"):
            find_state_dict_object(state_dict, MetadataIndex("st", [1]))

    def test_dcp_logger(self):
        # DCP keeps its own logger object, distinct from the shared c10d one.
        self.assertTrue(_dcp_logger is not _c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
class TestReaderView(TestCase):
def setUp(self):
buffer = io.BytesIO(bytearray(range(ord("A"), ord("Z") + 1)))
self.front_view = _create_file_view(buffer, 0, 5)
buffer = io.BytesIO(bytearray(range(ord("A"), ord("Z") + 1)))
self.middle_view = _create_file_view(buffer, 10, 5)
buffer = io.BytesIO(bytearray(range(ord("A"), ord("Z") + 1)))
self.back_view = _create_file_view(buffer, len(buffer.getbuffer()) - 5, 5)
def testShortRead(self):
self.assertEqual(self.front_view.read(3), b"ABC")
self.assertEqual(self.middle_view.read(3), b"KLM")
self.assertEqual(self.back_view.read(3), b"VWX")
def testLongRead(self):
self.assertEqual(self.front_view.read(10), b"ABCDE")
self.assertEqual(self.middle_view.read(10), b"KLMNO")
self.assertEqual(self.back_view.read(10), b"VWXYZ")
def testAllRead(self):
self.assertEqual(self.front_view.read(-1), b"ABCDE")
self.assertEqual(self.middle_view.read(-1), b"KLMNO")
self.assertEqual(self.back_view.read(-1), b"VWXYZ")
def testShortReadinto(self):
ba = bytearray(3)
self.assertEqual(self.front_view.readinto(ba), 3)
self.assertEqual(ba, b"ABC")
self.assertEqual(self.middle_view.readinto(ba), 3)
self.assertEqual(ba, b"KLM")
self.assertEqual(self.back_view.readinto(ba), 3)
self.assertEqual(ba, b"VWX")
def testLongReadinto(self):
ba = bytearray(8)
self.assertEqual(self.front_view.readinto(ba), 5)
self.assertEqual(ba, b"ABCDE\0\0\0")
self.assertEqual(self.front_view.readinto(ba), 0)
self.assertEqual(ba, b"ABCDE\0\0\0")
self.assertEqual(self.middle_view.readinto(ba), 5)
self.assertEqual(ba, b"KLMNO\0\0\0")
self.assertEqual(self.middle_view.readinto(ba), 0)
self.assertEqual(ba, b"KLMNO\0\0\0")
self.assertEqual(self.back_view.readinto(ba), 5)
self.assertEqual(ba, b"VWXYZ\0\0\0")
self.assertEqual(self.back_view.readinto(ba), 0)
self.assertEqual(ba, b"VWXYZ\0\0\0")
# Standard PyTorch test entry point.
if __name__ == "__main__":
    run_tests()