diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 95aad6e29b7..0d10382605d 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -18,6 +18,9 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available( fi popd +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. # This environment variable makes ProcessGroupGloo default to diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 075d025c2ef..0535bb9066d 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -46,6 +46,9 @@ BUILD_BIN_DIR="$BUILD_DIR"/bin SHARD_NUMBER="${SHARD_NUMBER:=1}" NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}" +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + export VALGRIND=ON # export TORCH_INDUCTOR_INSTALL_GXX=ON if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index aa2a2cf88f9..0426982a3ad 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -18,6 +18,9 @@ export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-result PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") export PYTORCH_FINAL_PACKAGE_DIR_WIN +# enable debug asserts in serialization +export TORCH_SERIALIZATION_DEBUG=1 + mkdir -p "$TMP_DIR"/build/torch export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 2554d25fd2a..01f6da12fbf 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -250,11 +250,8 @@ constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28; constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50; namespace detail { -size_t getPadding( - size_t cursor, - size_t filename_size, - size_t size, - std::string& padding_buf) { + 
+std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) { size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size + sizeof(mz_uint16) * 2; if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) { @@ -268,6 +265,16 @@ size_t getPadding( } size_t mod = start % kFieldAlignment; size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod); + std::tuple result(next_offset, start); + return result; +} + +size_t getPadding( + size_t cursor, + size_t filename_size, + size_t size, + std::string& padding_buf) { + auto [next_offset, start] = getOffset(cursor, filename_size, size); size_t padding_size = next_offset - start; size_t padding_size_plus_fbxx = padding_size + 4; if (padding_buf.size() < padding_size_plus_fbxx) { @@ -610,6 +617,17 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) { return stat.m_uncomp_size; } +size_t PyTorchStreamReader::getRecordOffsetNoRead( + size_t cursor, + std::string filename, + size_t size) { + std::string full_name = archive_name_plus_slash_ + filename; + size_t full_name_size = full_name.size(); + std::tuple result = detail::getOffset(cursor, full_name_size, size); + size_t offset = std::get<0>(result); + return offset; +} + PyTorchStreamReader::~PyTorchStreamReader() { mz_zip_clear_last_error(ar_.get()); mz_zip_reader_end(ar_.get()); diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 59e0991399a..0b6bb37b67b 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -174,6 +174,8 @@ class TORCH_API PyTorchStreamReader final { size_t getRecordSize(const std::string& name); size_t getRecordOffset(const std::string& name); + size_t +getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size); bool hasRecord(const std::string& name); std::vector<std::string> getAllRecords(); @@ -289,6 +291,9 @@ size_t getPadding( size_t filename_size, size_t size, std::string& padding_buf); + +std::tuple<size_t, size_t> getOffset(size_t 
cursor, size_t filename_size, size_t size); + } // namespace detail } // namespace serialize diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 3b74878e239..b3ba4feb22e 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -515,3 +515,6 @@ Config (Default : ``torch.serialization.LoadEndianness.NATIVE``) * ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`. (Default : ``MAP_PRIVATE``) + * ``calculate_storage_offsets``: If this config is set to ``True``, offsets for storages will be + calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes + random reads, which can be helpful when the file is being loaded over a network. (Default : ``False``) diff --git a/test/test_serialization.py b/test/test_serialization.py index ef7945a04f4..9945321bf6a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -45,6 +45,7 @@ from torch.testing._internal.common_utils import ( BytesIOContext, download_file, instantiate_parametrized_tests, + IS_CI, IS_FBCODE, IS_FILESYSTEM_UTF8_ENCODING, IS_WINDOWS, @@ -827,6 +828,11 @@ class SerializationMixin: loaded_data = torch.load(f, weights_only=True) self.assertEqual(data, loaded_data) + @unittest.skipIf(not IS_CI, "only check debug var is set in CI") + def test_debug_set_in_ci(self): + # This test is to make sure that the serialization debug flag is set in CI + self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1") + class serialization_method: def __init__(self, use_zip): @@ -1041,6 +1047,23 @@ class TestSerialization(TestCase, SerializationMixin): f.seek(0) state = torch.load(f) + @serialTest() + def test_serialization_4gb_file(self): + ''' + This is a specially engineered testcase that would fail if the data_descriptor size + had been incorrectly set as data_descriptor_size32 when it should be data_descriptor_size64 + ''' + # Run GC to clear up as much 
memory as possible before running this test + gc.collect() + big_model = torch.nn.ModuleList([torch.nn.Linear(1, int(1024 * 1024 * 1024) + 12, bias=False), + torch.nn.Linear(1, 1, bias=False).to(torch.float8_e4m3fn), + torch.nn.Linear(1, 2, bias=False).to(torch.float8_e4m3fn)]) + + with BytesIOContext() as f: + torch.save(big_model.state_dict(), f) + f.seek(0) + torch.load(f) + @parametrize('weights_only', (True, False)) def test_pathlike_serialization(self, weights_only): model = torch.nn.Conv2d(20, 3200, kernel_size=3) @@ -4533,6 +4556,30 @@ class TestSerialization(TestCase, SerializationMixin): self.assertTrue(opened_zipfile.has_record(".format_version")) self.assertEqual(opened_zipfile.get_record(".format_version"), b'1') + @parametrize('path_type', (str, Path)) + @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows") + def test_mmap_load_offset_calculation(self, path_type): + calculate_offsets_before = serialization_config.load.calculate_storage_offsets + try: + serialization_config.load.calculate_storage_offsets = True + m = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + + with TemporaryFileName() as f: + f = path_type(f) + state_dict = m.state_dict() + torch.save(state_dict, f) + result = torch.load(f, mmap=True) + result_non_mmap = torch.load(f, mmap=False) + + with torch.device("meta"): + model_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + model_non_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)]) + model_mmap_state_dict.load_state_dict(result, assign=True) + model_non_mmap_state_dict.load_state_dict(result_non_mmap, assign=True) + inp = torch.randn(4, 4) + self.assertEqual(model_mmap_state_dict(inp), model_non_mmap_state_dict(inp.clone())) + finally: + serialization_config.load.calculate_storage_offsets = calculate_offsets_before def run(self, *args, **kwargs): with serialization_method(use_zip=True): diff --git a/torch/csrc/jit/python/init.cpp 
b/torch/csrc/jit/python/init.cpp index 411e61cf912..b64e2a9c050 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1619,6 +1619,15 @@ void initJITBindings(PyObject* module) { "get_record_offset", [](PyTorchStreamReader& self, const std::string& key) { return self.getRecordOffset(key); + }) + .def( + "get_record_offset_no_read", + [](PyTorchStreamReader& self, + size_t zipfile_header_offset, + const std::string filename, + size_t size) { + return self.getRecordOffsetNoRead( + zipfile_header_offset, filename, size); }); // Used by torch.Package to coordinate deserialization of storages across diff --git a/torch/serialization.py b/torch/serialization.py index b9d123c82f8..7658e375f34 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -15,7 +15,7 @@ import threading import warnings from contextlib import closing, contextmanager from enum import Enum -from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union +from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union from typing_extensions import TypeAlias, TypeIs import torch @@ -1856,6 +1856,11 @@ def _load( loaded_storages = {} + can_calculate_storage_offsets = False + if zip_file.has_record(".format_version"): + version = zip_file.get_record(".format_version") + can_calculate_storage_offsets = version >= b"1" + # check if byteswapping is needed byteordername = "byteorder" byteorderdata = None @@ -1891,15 +1896,90 @@ def _load( UserWarning, ) + from torch.utils.serialization import config + + calculate_storage_offsets = config.load.calculate_storage_offsets + run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1" + current_offset = None + # constants from miniz.h/miniz.c + data_descripter_size64 = 24 + data_descripter_size32 = 16 + mz_uint32_max = 0xFFFFFFFF + offsets: Dict[str, int] = dict() + + def _get_offset(key, name, numel): + """ + Return the offset of the storage associated with key with record 
name `name` and size numel. + It is expected that the zipfile header of this storage starts at current_offset. + + WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular, + the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept + in sync with that of miniz! + + After reading a storage of size numel that starts at storage_offset + if it is the first time that storage was read, update nonlocal variable + current_offset to the start of the next zipfile header by incrementing + it by numel and the data descriptor size. + """ + nonlocal current_offset, offsets + if name in offsets: + storage_offset = offsets[name] + return storage_offset + + if current_offset is None: + assert key == "0" + current_offset = zip_file.get_record_offset(name) + storage_offset = current_offset + else: + storage_offset = zip_file.get_record_offset_no_read( + current_offset, name, numel + ) + + # This is only actually needed for storages that have typed_storage._data_ptr() == 0 + # after being read. Otherwise persistent_load would never "re-call" load_tensor + # for a given key. 
+ offsets[name] = storage_offset + + # Increment current_offset of offset where next zipfile header starts + local_header_offset = current_offset + current_offset = storage_offset + numel + # add size of data descriptor after payload + if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max: + current_offset += data_descripter_size64 + else: + current_offset += data_descripter_size32 + + return storage_offset + def load_tensor(dtype, numel, key, location): name = f"data/{key}" if torch._guards.detect_fake_mode(None) is not None: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes, device="meta") elif overall_storage is not None: - storage_offset = zip_file.get_record_offset(name) + if can_calculate_storage_offsets and calculate_storage_offsets: + storage_offset = _get_offset(key, name, numel) + if run_debug_asserts: + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + else: + storage_offset = zip_file.get_record_offset(name) storage = overall_storage[storage_offset : storage_offset + numel] else: + if can_calculate_storage_offsets and run_debug_asserts: + # This is debug code that we use to test the validity of + # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI + storage_offset = _get_offset(key, name, numel) + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) storage = ( zip_file.get_storage_from_record(name, numel, torch.UntypedStorage) ._typed_storage() diff --git a/torch/utils/serialization/config.py 
b/torch/utils/serialization/config.py index 77138676d1c..0ef12f77d9d 100644 --- a/torch/utils/serialization/config.py +++ b/torch/utils/serialization/config.py @@ -13,6 +13,7 @@ class load: endianness: _Optional["_LoadEndianess"] = None # MAP_PRIVATE = 2 mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2 + calculate_storage_offsets: bool = False class save: