From 001e355a56a6b25fddda1395d5508d62571a18cb Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 30 Jan 2025 10:40:56 -0800 Subject: [PATCH] Add option to serialization config to reduce random reads from get_record_offset when loading with mmap=True (#143880) ## Background This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. This option relies on the previous PR in this stack, where storage order was changed to non lexicographical. A `.format_version` entry was added to the zipfile and `calculate_storage_offsets` will only work on checkpoints with `.format_version`. When this is turned on, for `torch.load(mmap=True)`, offsets of each storage record (other than the 0th storage will be calculated instead of relying on `miniz` APIs to determine this). The existing APIs will issue multiple random reads (reading the end of central directory record, then reading the zipfile header for the record) to determine the storage offset where the record starts. This can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases. https://github.com/pytorch/pytorch/blob/6aaae9d78f0992ac6265552e4f8323ef11d62bb0/caffe2/serialize/inline_container.cc#L589-L605 ## How does this work The format for the checkpoint is as such ``` archive_name/ |_ data.pkl |_.format_version |_byteorder |_data/ |_ 0 |_ 1 |_ 2 |_ ... |_ ``` Each `data/i` record represents a storage, where storages are written in the order that the Pickler encounters them. For each storage, our `persistent_load` logic saves the following metadata to the pickle file `dtype, numel, key, location` where `numel` is the number of bytes in the storage. Note that we always use `miniz` writer in the zip64 mode per [here](https://github.com/pytorch/pytorch/blob/7796e308d0636bcbfb2490c80291edd440d4bc42/caffe2/serialize/inline_container.cc#L701) A zipfile record written by miniz looks as such ``` ---------------- ----------------- ------------------- ---------------- --------- ------------------------------ | 30 byte header | n byte filename | zip64_extra_data | m byte padding | storage | 16 or 24 byte local dir footer | ---------------- ----------------- ------------------- ---------------- --------- ------------------------------ ``` - The header size (30) is given by [`MZ_ZIP_LOCAL_DIR_HEADER_SIZE`](https://github.com/pytorch/pytorch/blob/main/third_party/miniz-3.0.2/miniz.c?fbclid=IwZXh0bgNhZW0CMTEAAR2O8Vysd--UoSCxW70gabXIS1dbz733oHwuUQ5_Ff1hY2WU6PL2i6CSH4A_aem_J9oaU2HpDeWtJKOU9EnVqw#L3290) - filename will be `"{archive_name}/{filepath}"` - `zip64_extra_data` is determined by [`mz_zip_writer_create_zip64_extra_data`](https://github.com/pytorch/pytorch/blob/7796e308d0636bcbfb2490c80291edd440d4bc42/third_party/miniz-3.0.2/miniz.c#L6202). Note that [we only create zip64_extra_data if storage_size >= 0xFFFFFFFF or the offset of the start of the header >= 0xFFFFFFFF](https://github.com/pytorch/pytorch/blob/7796e308d0636bcbfb2490c80291edd440d4bc42/third_party/miniz-3.0.2/miniz.c#L6519-L6524) - `m` is determined by [`getPadding`](https://github.com/pytorch/pytorch/blob/7796e308d0636bcbfb2490c80291edd440d4bc42/caffe2/serialize/inline_container.cc#L254), which accounts for filename, zip64_extra_data to determine `m` such that the start of `storage` is aligned to 64 bytes. The `m` bytes will always start with `F B padding_size" as the first 4 bytes - The local dir footer size is determined based on [this snippet ](https://github.com/pytorch/pytorch/blob/7796e308d0636bcbfb2490c80291edd440d4bc42/third_party/miniz-3.0.2/miniz.c#L6610-L6632): if the buffer size is 0 it is skipped. If the zip64_extra_data was created, it is 24, otherwise it is 16. When `torch.utils.serialization.config.load.calculate_storage_offsets` is set we do the following - We keep track of where the "cursor" is in the file using `current_offset`, after each persistent_load call, it will be at the offset where the header for the next record starts - for the 0th storage, "data/0", we use the regular get_record_offset to determine the start of the storage - for any other storage, (where the storages will be in order encountered by the unpickler, 0, 1, 2, 3, ...) we use `get_record_offset_no_read`, which re-uses the `getPadding` logic to determine the offset of the storage - Note that `load_tensor` will only ever be called again with the same key if the storage's `._data_ptr()` is 0 [[pointer1](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1917-L1918)][[pointer2](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1936-L1937)], so we cache the offsets for this edge case - After each storage, if the storage is non-zero, we account for the local dir footer based on the logic described above ## Testing strategy The agreed upon testing strategy was as follows: - Add debug code gated by an environment flag `TORCH_SERIALIZATION_DEBUG` that will run this offset calculation logic and verify it against getRecordOffset for each storage (when mmap=False) - This flag is set throughout CI, which means that every time `torch.load` is called, the offset calculation logic is implicitly being tested. Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880 Approved by: https://github.com/albanD ghstack dependencies: #143879 --- .ci/pytorch/macos-test.sh | 3 + .ci/pytorch/test.sh | 3 + .ci/pytorch/win-test.sh | 3 + caffe2/serialize/inline_container.cc | 36 ++++++++++-- caffe2/serialize/inline_container.h | 7 ++- docs/source/notes/serialization.rst | 3 + test/test_serialization.py | 47 +++++++++++++++ torch/csrc/jit/python/init.cpp | 14 +++++ torch/serialization.py | 86 +++++++++++++++++++++++++++- torch/utils/serialization/config.py | 1 + 10 files changed, 195 insertions(+), 8 deletions(-) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 95aad6e29b7..0d10382605d 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -18,6 +18,9 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available( fi popd +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. # This environment variable makes ProcessGroupGloo default to diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 075d025c2ef..0535bb9066d 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -46,6 +46,9 @@ BUILD_BIN_DIR="$BUILD_DIR"/bin SHARD_NUMBER="${SHARD_NUMBER:=1}" NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}" +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + export VALGRIND=ON # export TORCH_INDUCTOR_INSTALL_GXX=ON if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index aa2a2cf88f9..0426982a3ad 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -18,6 +18,9 @@ export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-result PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") export PYTORCH_FINAL_PACKAGE_DIR_WIN +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + mkdir -p "$TMP_DIR"/build/torch export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index ec993253340..2b8545af9f8 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -251,11 +251,8 @@ constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28; constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50; namespace detail { -size_t getPadding( - size_t cursor, - size_t filename_size, - size_t size, - std::string& padding_buf) { + +std::tuple getOffset(size_t cursor, size_t filename_size, size_t size) { size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size + sizeof(mz_uint16) * 2; if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) { @@ -269,6 +266,16 @@ size_t getPadding( } size_t mod = start % kFieldAlignment; size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod); + std::tuple result(next_offset, start); + return result; +} + +size_t getPadding( + size_t cursor, + size_t filename_size, + size_t size, + std::string& padding_buf) { + auto [next_offset, start] = getOffset(cursor, filename_size, size); size_t padding_size = next_offset - start; size_t padding_size_plus_fbxx = padding_size + 4; if (padding_buf.size() < padding_size_plus_fbxx) { @@ -587,6 +594,14 @@ static int64_t read_le_16(uint8_t* buf) { return buf[0] + (buf[1] << 8); } +size_t PyTorchStreamReader::getRecordHeaderOffset(const std::string& name) { + std::lock_guard guard(reader_lock_); + mz_zip_archive_file_stat stat; + mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat); + valid("retrieving file meta-data for ", name.c_str()); + return stat.m_local_header_ofs; +} + size_t PyTorchStreamReader::getRecordOffset(const std::string& name) { std::lock_guard guard(reader_lock_); mz_zip_archive_file_stat stat; @@ -611,6 +626,17 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) { return stat.m_uncomp_size; } +size_t PyTorchStreamReader::getRecordOffsetNoRead( + size_t cursor, + std::string filename, + size_t size) { + std::string full_name = archive_name_plus_slash_ + filename; + size_t full_name_size = full_name.size(); + std::tuple result = detail::getOffset(cursor, full_name_size, size); + size_t offset = std::get<0>(result); + return offset; +} + PyTorchStreamReader::~PyTorchStreamReader() { mz_zip_clear_last_error(ar_.get()); mz_zip_reader_end(ar_.get()); diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 59e0991399a..7b183fb0969 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -172,8 +172,10 @@ class TORCH_API PyTorchStreamReader final { size_t n); size_t getRecordSize(const std::string& name); - + size_t getRecordHeaderOffset(const std::string& name); size_t getRecordOffset(const std::string& name); + size_t + getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size); bool hasRecord(const std::string& name); std::vector getAllRecords(); @@ -289,6 +291,9 @@ size_t getPadding( size_t filename_size, size_t size, std::string& padding_buf); + +std::tuple getOffset(size_t cursor, size_t filename_size, size_t size); + } // namespace detail } // namespace serialize diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 3b74878e239..b3ba4feb22e 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -515,3 +515,6 @@ Config (Default : ``torch.serialization.LoadEndianness.NATIVE``) * ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`. (Default : ``MAP_PRIVATE``) + * ``calculate_storage_offsets``: If this config is set to ``True``, offsets for storages will be + calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes + random reads, which can be helpful when the file is being loaded over a network. (Default : ``False``) diff --git a/test/test_serialization.py b/test/test_serialization.py index ef7945a04f4..9945321bf6a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -45,6 +45,7 @@ from torch.testing._internal.common_utils import ( BytesIOContext, download_file, instantiate_parametrized_tests, + IS_CI, IS_FBCODE, IS_FILESYSTEM_UTF8_ENCODING, IS_WINDOWS, @@ -827,6 +828,11 @@ class SerializationMixin: loaded_data = torch.load(f, weights_only=True) self.assertEqual(data, loaded_data) + @unittest.skipIf(not IS_CI, "only check debug var is set in CI") + def test_debug_set_in_ci(self): + # This test is to make sure that the serialization debug flag is set in CI + self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1") + class serialization_method: def __init__(self, use_zip): @@ -1041,6 +1047,23 @@ class TestSerialization(TestCase, SerializationMixin): f.seek(0) state = torch.load(f) + @serialTest() + def test_serialization_4gb_file(self): + ''' + This is a specially engineered testcase that would fail if the data_descriptor size + had been incorrectly set as data_descriptor_size32 when it should be data_descriptor_size64 + ''' + # Run GC to clear up as much memory as possible before running this test + gc.collect() + big_model = torch.nn.ModuleList([torch.nn.Linear(1, int(1024 * 1024 * 1024) + 12, bias=False), + torch.nn.Linear(1, 1, bias=False).to(torch.float8_e4m3fn), + torch.nn.Linear(1, 2, bias=False).to(torch.float8_e4m3fn)]) + + with BytesIOContext() as f: + torch.save(big_model.state_dict(), f) + f.seek(0) + torch.load(f) + @parametrize('weights_only', (True, False)) def test_pathlike_serialization(self, weights_only): model = torch.nn.Conv2d(20, 3200, kernel_size=3) @@ -4533,6 +4556,30 @@ class TestSerialization(TestCase, SerializationMixin): self.assertTrue(opened_zipfile.has_record(".format_version")) self.assertEqual(opened_zipfile.get_record(".format_version"), b'1') + @parametrize('path_type', (str, Path)) + @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows") + def test_mmap_load_offset_calculation(self, path_type): + calculate_offsets_before = serialization_config.load.calculate_storage_offsets + try: + serialization_config.load.calculate_storage_offsets = True + m = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + + with TemporaryFileName() as f: + f = path_type(f) + state_dict = m.state_dict() + torch.save(state_dict, f) + result = torch.load(f, mmap=True) + result_non_mmap = torch.load(f, mmap=False) + + with torch.device("meta"): + model_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + model_non_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + model_mmap_state_dict.load_state_dict(result, assign=True) + model_non_mmap_state_dict.load_state_dict(result_non_mmap, assign=True) + inp = torch.randn(4, 4) + self.assertEqual(model_mmap_state_dict(inp), model_non_mmap_state_dict(inp.clone())) + finally: + serialization_config.load.calculate_storage_offsets = calculate_offsets_before def run(self, *args, **kwargs): with serialization_method(use_zip=True): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 411e61cf912..2ba2094b3f3 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1619,6 +1619,20 @@ void initJITBindings(PyObject* module) { "get_record_offset", [](PyTorchStreamReader& self, const std::string& key) { return self.getRecordOffset(key); + }) + .def( + "get_record_header_offset", + [](PyTorchStreamReader& self, const std::string& key) { + return self.getRecordHeaderOffset(key); + }) + .def( + "get_record_offset_no_read", + [](PyTorchStreamReader& self, + size_t zipfile_header_offset, + const std::string filename, + size_t size) { + return self.getRecordOffsetNoRead( + zipfile_header_offset, filename, size); }); // Used by torch.Package to coordinate deserialization of storages across diff --git a/torch/serialization.py b/torch/serialization.py index b9d123c82f8..e1ca2b0ab88 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -15,7 +15,7 @@ import threading import warnings from contextlib import closing, contextmanager from enum import Enum -from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union +from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union from typing_extensions import TypeAlias, TypeIs import torch @@ -1856,6 +1856,11 @@ def _load( loaded_storages = {} + can_calculate_storage_offsets = False + if zip_file.has_record(".format_version"): + version = zip_file.get_record(".format_version") + can_calculate_storage_offsets = version >= b"1" + # check if byteswapping is needed byteordername = "byteorder" byteorderdata = None @@ -1891,15 +1896,92 @@ def _load( UserWarning, ) + from torch.utils.serialization import config + + calculate_storage_offsets = config.load.calculate_storage_offsets + run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1" + current_offset = None + # constants from miniz.h/miniz.c + data_descripter_size64 = 24 + data_descripter_size32 = 16 + mz_uint32_max = 0xFFFFFFFF + offsets: Dict[str, int] = dict() + + def _get_offset(key, name, numel): + """ + Return the offset of the storage associated with key with record name `name` and size numel. + It is expected that the zipfile header of this storage starts at current_offset. + + WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular, + the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept + in sync with that of miniz! + + After reading a storage of size numel that starts at storage_offset + if it is the first time that storage was read, update nonlocal variable + current_offset to the start of the next zipfile header by incrementing + it by numel and the data descriptor size. + """ + nonlocal current_offset, offsets + if name in offsets: + storage_offset = offsets[name] + return storage_offset + + if current_offset is None: + assert key == "0" + current_offset = zip_file.get_record_offset(name) + local_header_offset = zip_file.get_record_header_offset(name) + storage_offset = current_offset + else: + storage_offset = zip_file.get_record_offset_no_read( + current_offset, name, numel + ) + local_header_offset = current_offset + + # This is only actually needed for storages that have typed_storage._data_ptr() == 0 + # after being read. Otherwise persistent_load would never "re-call" load_tensor + # for a given key. + offsets[name] = storage_offset + + # Increment current_offset of offset where next zipfile header starts + current_offset = storage_offset + numel + # add size of data descriptor after payload + if numel > 0: + if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max: + current_offset += data_descripter_size64 + else: + current_offset += data_descripter_size32 + + return storage_offset + def load_tensor(dtype, numel, key, location): name = f"data/{key}" if torch._guards.detect_fake_mode(None) is not None: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes, device="meta") elif overall_storage is not None: - storage_offset = zip_file.get_record_offset(name) + if can_calculate_storage_offsets and calculate_storage_offsets: + storage_offset = _get_offset(key, name, numel) + if run_debug_asserts: + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + else: + storage_offset = zip_file.get_record_offset(name) storage = overall_storage[storage_offset : storage_offset + numel] else: + if can_calculate_storage_offsets and run_debug_asserts: + # This is debug code that we use to test the validity of + # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI + storage_offset = _get_offset(key, name, numel) + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) storage = ( zip_file.get_storage_from_record(name, numel, torch.UntypedStorage) ._typed_storage() diff --git a/torch/utils/serialization/config.py b/torch/utils/serialization/config.py index 77138676d1c..0ef12f77d9d 100644 --- a/torch/utils/serialization/config.py +++ b/torch/utils/serialization/config.py @@ -13,6 +13,7 @@ class load: endianness: _Optional["_LoadEndianess"] = None # MAP_PRIVATE = 2 mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2 + calculate_storage_offsets: bool = False class save: