From f02cfcc802555e0f47a97aa1338178505383400f Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Mon, 19 Jul 2021 18:20:53 -0700 Subject: [PATCH] ban PyTorchStreamWriter from writing the same file twice (#61805) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61805 Similar in spirit to https://github.com/pytorch/pytorch/pull/61371. While writing two files with the same name is allowed by the ZIP format, most tools (including our own) handle this poorly. Previously I banned this within `PackageExporter`, but that doesn't cover other uses of the zip format like TorchScript. Given that there are no valid use cases and debugging issues caused by multiple file writes is fiendishly difficult, banning this behavior enitrely. Differential Revision: D29748968 D29748968 Test Plan: Imported from OSS Reviewed By: Lilyjjo Pulled By: suo fbshipit-source-id: 0afee1506c59c0f283ef41e4be562f9c22f21023 --- caffe2/serialize/inline_container.cc | 10 +++++++--- caffe2/serialize/inline_container.h | 7 ++++--- caffe2/serialize/inline_container_test.cc | 16 ++++++++++------ torch/csrc/jit/mobile/backport_manager.cpp | 2 +- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index d74dfb4a4a1..cb36fe0a4bc 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include "caffe2/core/common.h" #include "caffe2/core/logging.h" @@ -235,8 +236,9 @@ std::vector PyTorchStreamReader::getAllRecords() { return out; } -const std::vector& PyTorchStreamWriter::getAllWrittenRecords() { - return files_written; +const std::unordered_set& +PyTorchStreamWriter::getAllWrittenRecords() { + return files_written_; } size_t PyTorchStreamReader::getRecordID(const std::string& name) { @@ -356,6 +358,8 @@ void PyTorchStreamWriter::writeRecord( bool compress) { AT_ASSERT(!finalized_); AT_ASSERT(!archive_name_plus_slash_.empty()); + TORCH_INTERNAL_ASSERT( + files_written_.count(name) == 0, "Tried to serialize file twice: ", name); std::string full_name = archive_name_plus_slash_ + name; size_t padding_size = detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_); @@ -376,7 +380,7 @@ void PyTorchStreamWriter::writeRecord( nullptr, 0); valid("writing file ", name.c_str()); - files_written.push_back(name); + files_written_.insert(name); } void PyTorchStreamWriter::writeEndOfFile() { diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 281d1756d75..4eb1b8e71ce 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -140,7 +141,7 @@ class TORCH_API PyTorchStreamWriter final { bool compress = false); void writeEndOfFile(); - const std::vector& getAllWrittenRecords(); + const std::unordered_set& getAllWrittenRecords(); bool finalized() const { return finalized_; @@ -156,7 +157,7 @@ class TORCH_API PyTorchStreamWriter final { void setup(const std::string& file_name); void valid(const char* what, const char* info = ""); size_t current_pos_ = 0; - std::vector files_written; + std::unordered_set files_written_; std::unique_ptr ar_; std::string archive_name_; std::string archive_name_plus_slash_; @@ -184,7 +185,7 @@ size_t getPadding( size_t filename_size, size_t size, std::string& padding_buf); -} +} // namespace detail } // namespace serialize } // namespace caffe2 diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 3a9f511ee9c..7a65bf1ab45 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -35,9 +35,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { } writer.writeRecord("key2", data2.data(), data2.size()); - const std::vector& written_records = writer.getAllWrittenRecords(); - ASSERT_EQ(written_records[0], "key1"); - ASSERT_EQ(written_records[1], "key2"); + const std::unordered_set& written_records = + writer.getAllWrittenRecords(); + ASSERT_EQ(written_records.size(), 2); + ASSERT_EQ(written_records.count("key1"), 1); + ASSERT_EQ(written_records.count("key2"), 1); writer.writeEndOfFile(); @@ -95,9 +97,11 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) { } writer.writeRecord("key2", data2.data(), data2.size()); - const std::vector& written_records = writer.getAllWrittenRecords(); - ASSERT_EQ(written_records[0], "key1"); - ASSERT_EQ(written_records[1], "key2"); + const std::unordered_set& written_records = + writer.getAllWrittenRecords(); + ASSERT_EQ(written_records.size(), 2); + ASSERT_EQ(written_records.count("key1"), 1); + ASSERT_EQ(written_records.count("key2"), 1); writer.writeEndOfFile(); diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp index 91c8548ee7d..3cd815626c0 100644 --- a/torch/csrc/jit/mobile/backport_manager.cpp +++ b/torch/csrc/jit/mobile/backport_manager.cpp @@ -244,7 +244,7 @@ void writeArchiveV5( std::string prefix = archive_name + "/"; TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size()); - const std::vector& pre_serialized_files = + const std::unordered_set& pre_serialized_files = writer.getAllWrittenRecords(); for (const auto& td : data_pickle.tensorData()) {