Add option to serialization config to reduce random reads from get_record_offset when loading with mmap=True (#143880)

## Background

This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. This option relies on the previous PR in this stack, where storage order was changed to non-lexicographical. A `.format_version` entry was added to the zipfile and `calculate_storage_offsets` will only work on checkpoints with `.format_version`.

When this is turned on, for `torch.load(mmap=True)`, offsets of each storage record (other than the 0th storage) will be calculated instead of relying on `miniz` APIs to determine this.

The existing APIs will issue multiple random reads (reading the end of central directory record, then reading the zipfile header for the record) to determine the storage offset where the record starts. This can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases.

6aaae9d78f/caffe2/serialize/inline_container.cc (L589-L605)

## Testing strategy

The agreed upon testing strategy was as follows:
- Add debug code gated by an environment flag `TORCH_SERIALIZATION_DEBUG` that will run this offset calculation logic and verify it against getRecordOffset for each storage (when mmap=False)
- This flag is set throughout CI, which means that every time `torch.load` is called, the offset calculation logic is implicitly being tested.

Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880
Approved by: https://github.com/albanD
ghstack dependencies: #143879
This commit is contained in:
Mikayla Gawarecki 2025-01-27 10:18:13 -08:00 committed by PyTorch MergeBot
parent 7db0afabaa
commit db3685a35c
10 changed files with 179 additions and 7 deletions

View file

@ -18,6 +18,9 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
fi
popd
# enable debug asserts in serialization
export TORCH_SERIALIZATION_DEBUG=1
setup_test_python() {
# The CircleCI worker hostname doesn't resolve to an address.
# This environment variable makes ProcessGroupGloo default to

View file

@ -46,6 +46,9 @@ BUILD_BIN_DIR="$BUILD_DIR"/bin
SHARD_NUMBER="${SHARD_NUMBER:=1}"
NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}"
# enable debug asserts in serialization
export TORCH_SERIALIZATION_DEBUG=1
export VALGRIND=ON
# export TORCH_INDUCTOR_INSTALL_GXX=ON
if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then

View file

@ -18,6 +18,9 @@ export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-result
PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
export PYTORCH_FINAL_PACKAGE_DIR_WIN
# enable debug asserts in serialization
export TORCH_SERIALIZATION_DEBUG=1
mkdir -p "$TMP_DIR"/build/torch
export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers

View file

@ -250,11 +250,8 @@ constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
namespace detail {
size_t getPadding(
size_t cursor,
size_t filename_size,
size_t size,
std::string& padding_buf) {
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) {
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
@ -268,6 +265,16 @@ size_t getPadding(
}
size_t mod = start % kFieldAlignment;
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
std::tuple<size_t, size_t> result(next_offset, start);
return result;
}
size_t getPadding(
size_t cursor,
size_t filename_size,
size_t size,
std::string& padding_buf) {
auto [next_offset, start] = getOffset(cursor, filename_size, size);
size_t padding_size = next_offset - start;
size_t padding_size_plus_fbxx = padding_size + 4;
if (padding_buf.size() < padding_size_plus_fbxx) {
@ -610,6 +617,17 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
return stat.m_uncomp_size;
}
size_t PyTorchStreamReader::getRecordOffsetNoRead(
    size_t cursor,
    std::string filename,
    size_t size) {
  // Compute the data offset of the record arithmetically from the position of
  // its zipfile local header (`cursor`), without issuing any reads against the
  // underlying archive. The record name inside the archive is prefixed with
  // the archive name plus a slash, so account for that in the filename length.
  const std::string full_name = archive_name_plus_slash_ + filename;
  // getOffset returns {aligned_data_offset, unaligned_start}; only the aligned
  // offset is relevant here.
  const auto offsets = detail::getOffset(cursor, full_name.size(), size);
  return std::get<0>(offsets);
}
PyTorchStreamReader::~PyTorchStreamReader() {
mz_zip_clear_last_error(ar_.get());
mz_zip_reader_end(ar_.get());

View file

@ -174,6 +174,8 @@ class TORCH_API PyTorchStreamReader final {
size_t getRecordSize(const std::string& name);
size_t getRecordOffset(const std::string& name);
size_t
getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size);
bool hasRecord(const std::string& name);
std::vector<std::string> getAllRecords();
@ -289,6 +291,9 @@ size_t getPadding(
size_t filename_size,
size_t size,
std::string& padding_buf);
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size);
} // namespace detail
} // namespace serialize

View file

@ -515,3 +515,6 @@ Config
(Default : ``torch.serialization.LoadEndianness.NATIVE``)
* ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`.
(Default : ``MAP_PRIVATE``)
* ``calculate_storage_offsets``: If this config is set to ``True``, offsets for storages will be
calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes
random reads, which can be helpful when the file is being loaded over a network. (Default : ``False``)

View file

@ -45,6 +45,7 @@ from torch.testing._internal.common_utils import (
BytesIOContext,
download_file,
instantiate_parametrized_tests,
IS_CI,
IS_FBCODE,
IS_FILESYSTEM_UTF8_ENCODING,
IS_WINDOWS,
@ -827,6 +828,11 @@ class SerializationMixin:
loaded_data = torch.load(f, weights_only=True)
self.assertEqual(data, loaded_data)
@unittest.skipIf(not IS_CI, "only check debug var is set in CI")
def test_debug_set_in_ci(self):
    # Guard that TORCH_SERIALIZATION_DEBUG is exported in CI: with it set,
    # every torch.load call implicitly validates the storage-offset
    # calculation logic against getRecordOffset.
    # Use assertEqual (not assertTrue on a comparison) so a failure reports
    # the actual value of the environment variable.
    self.assertEqual(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0"), "1")
class serialization_method:
def __init__(self, use_zip):
@ -1041,6 +1047,23 @@ class TestSerialization(TestCase, SerializationMixin):
f.seek(0)
state = torch.load(f)
@serialTest()
def test_serialization_4gb_file(self):
    '''
    Engineered so a single storage exceeds 4GB: this would fail if the data
    descriptor size had been incorrectly taken as data_descriptor_size32 when
    it should be data_descriptor_size64 for such records.
    '''
    # Reclaim as much memory as possible up front; this test materializes a
    # tensor larger than 4GB.
    gc.collect()
    big_model = torch.nn.ModuleList([
        torch.nn.Linear(1, int(1024 * 1024 * 1024) + 12, bias=False),
        torch.nn.Linear(1, 1, bias=False).to(torch.float8_e4m3fn),
        torch.nn.Linear(1, 2, bias=False).to(torch.float8_e4m3fn),
    ])
    with BytesIOContext() as f:
        torch.save(big_model.state_dict(), f)
        f.seek(0)
        # A successful round-trip is the assertion here.
        torch.load(f)
@parametrize('weights_only', (True, False))
def test_pathlike_serialization(self, weights_only):
model = torch.nn.Conv2d(20, 3200, kernel_size=3)
@ -4533,6 +4556,30 @@ class TestSerialization(TestCase, SerializationMixin):
self.assertTrue(opened_zipfile.has_record(".format_version"))
self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
@parametrize('path_type', (str, Path))
@unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
def test_mmap_load_offset_calculation(self, path_type):
    # Check that loading with calculate_storage_offsets enabled (mmap=True)
    # produces the same tensors as a regular mmap=False load.
    saved_flag = serialization_config.load.calculate_storage_offsets
    try:
        serialization_config.load.calculate_storage_offsets = True

        def build():
            # 20 identical layers so many storage offsets must be computed.
            return torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])

        source = build()
        with TemporaryFileName() as fname:
            fname = path_type(fname)
            torch.save(source.state_dict(), fname)
            sd_mmap = torch.load(fname, mmap=True)
            sd_regular = torch.load(fname, mmap=False)
            with torch.device("meta"):
                net_mmap = build()
                net_regular = build()
            net_mmap.load_state_dict(sd_mmap, assign=True)
            net_regular.load_state_dict(sd_regular, assign=True)
            x = torch.randn(4, 4)
            self.assertEqual(net_mmap(x), net_regular(x.clone()))
    finally:
        serialization_config.load.calculate_storage_offsets = saved_flag
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):

View file

@ -1619,6 +1619,15 @@ void initJITBindings(PyObject* module) {
"get_record_offset",
[](PyTorchStreamReader& self, const std::string& key) {
return self.getRecordOffset(key);
})
.def(
"get_record_offset_no_read",
[](PyTorchStreamReader& self,
size_t zipfile_header_offset,
const std::string filename,
size_t size) {
return self.getRecordOffsetNoRead(
zipfile_header_offset, filename, size);
});
// Used by torch.Package to coordinate deserialization of storages across

View file

@ -15,7 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union
from typing_extensions import TypeAlias, TypeIs
import torch
@ -1856,6 +1856,11 @@ def _load(
loaded_storages = {}
can_calculate_storage_offsets = False
if zip_file.has_record(".format_version"):
version = zip_file.get_record(".format_version")
can_calculate_storage_offsets = version >= b"1"
# check if byteswapping is needed
byteordername = "byteorder"
byteorderdata = None
@ -1891,15 +1896,90 @@ def _load(
UserWarning,
)
from torch.utils.serialization import config
calculate_storage_offsets = config.load.calculate_storage_offsets
run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
current_offset = None
# constants from miniz.h/miniz.c
data_descripter_size64 = 24
data_descripter_size32 = 16
mz_uint32_max = 0xFFFFFFFF
offsets: Dict[str, int] = dict()
def _get_offset(key, name, numel):
    """Return the byte offset of the storage record `name` of size `numel`.

    The zipfile local header of this record is expected to start at the
    nonlocal `current_offset`. After computing the storage's offset, advance
    `current_offset` past this record's payload and trailing data descriptor,
    i.e. to where the next record's zipfile header begins.

    WARNING: This function relies on the behavior of the zip writer in
    miniz.c — in particular `mz_zip_writer_add_mem_ex_v2` — and must be
    kept in sync with it.
    """
    nonlocal current_offset, offsets
    # Memoized result. This is only actually needed for storages that have
    # typed_storage._data_ptr() == 0 after being read; otherwise
    # persistent_load would never "re-call" load_tensor for a given key.
    if name in offsets:
        storage_offset = offsets[name]
        return storage_offset
    if current_offset is None:
        # First storage read: seed the cursor with a real lookup.
        assert key == "0"
        current_offset = zip_file.get_record_offset(name)
        storage_offset = current_offset
    else:
        # Subsequent storages: derive the offset arithmetically, no reads.
        storage_offset = zip_file.get_record_offset_no_read(
            current_offset, name, numel
        )
    offsets[name] = storage_offset
    # Advance current_offset to where the next zipfile header starts:
    # past the payload, plus the data descriptor appended after it.
    local_header_offset = current_offset
    current_offset = storage_offset + numel
    # The 64-bit descriptor is used when either the header offset or the
    # payload size does not fit in 32 bits (mirrors miniz behavior).
    if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
        current_offset += data_descripter_size64
    else:
        current_offset += data_descripter_size32
    return storage_offset
def load_tensor(dtype, numel, key, location):
name = f"data/{key}"
if torch._guards.detect_fake_mode(None) is not None:
nbytes = numel * torch._utils._element_size(dtype)
storage = torch.UntypedStorage(nbytes, device="meta")
elif overall_storage is not None:
storage_offset = zip_file.get_record_offset(name)
if can_calculate_storage_offsets and calculate_storage_offsets:
storage_offset = _get_offset(key, name, numel)
if run_debug_asserts:
if storage_offset != zip_file.get_record_offset(name):
raise RuntimeError(
"This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
f"{zip_file.get_record_offset(name)}"
)
else:
storage_offset = zip_file.get_record_offset(name)
storage = overall_storage[storage_offset : storage_offset + numel]
else:
if can_calculate_storage_offsets and run_debug_asserts:
# This is debug code that we use to test the validity of
# torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
storage_offset = _get_offset(key, name, numel)
if storage_offset != zip_file.get_record_offset(name):
raise RuntimeError(
"This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
f"{zip_file.get_record_offset(name)}"
)
storage = (
zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
._typed_storage()

View file

@ -13,6 +13,7 @@ class load:
endianness: _Optional["_LoadEndianess"] = None
# MAP_PRIVATE = 2
mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2
calculate_storage_offsets: bool = False
class save: