mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61805 Similar in spirit to https://github.com/pytorch/pytorch/pull/61371. While writing two files with the same name is allowed by the ZIP format, most tools (including our own) handle this poorly. Previously I banned this within `PackageExporter`, but that doesn't cover other uses of the zip format like TorchScript. Given that there are no valid use cases and debugging issues caused by multiple file writes is fiendishly difficult, banning this behavior entirely. Differential Revision: D29748968 D29748968 Test Plan: Imported from OSS Reviewed By: Lilyjjo Pulled By: suo fbshipit-source-id: 0afee1506c59c0f283ef41e4be562f9c22f21023
126 lines
4 KiB
C++
126 lines
4 KiB
C++
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_set>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
namespace caffe2 {
|
|
namespace serialize {
|
|
namespace {
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
|
int64_t kFieldAlignment = 64L;
|
|
|
|
std::ostringstream oss;
|
|
// write records through writers
|
|
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
|
oss.write(static_cast<const char*>(b), n);
|
|
return oss ? n : 0;
|
|
});
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
|
std::array<char, 127> data1;
|
|
|
|
for (int i = 0; i < data1.size(); ++i) {
|
|
data1[i] = data1.size() - i;
|
|
}
|
|
writer.writeRecord("key1", data1.data(), data1.size());
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
|
std::array<char, 64> data2;
|
|
for (int i = 0; i < data2.size(); ++i) {
|
|
data2[i] = data2.size() - i;
|
|
}
|
|
writer.writeRecord("key2", data2.data(), data2.size());
|
|
|
|
const std::unordered_set<std::string>& written_records =
|
|
writer.getAllWrittenRecords();
|
|
ASSERT_EQ(written_records.size(), 2);
|
|
ASSERT_EQ(written_records.count("key1"), 1);
|
|
ASSERT_EQ(written_records.count("key2"), 1);
|
|
|
|
writer.writeEndOfFile();
|
|
|
|
std::string the_file = oss.str();
|
|
std::ofstream foo("output.zip");
|
|
foo.write(the_file.c_str(), the_file.size());
|
|
foo.close();
|
|
|
|
std::istringstream iss(the_file);
|
|
|
|
// read records through readers
|
|
PyTorchStreamReader reader(&iss);
|
|
ASSERT_TRUE(reader.hasRecord("key1"));
|
|
ASSERT_TRUE(reader.hasRecord("key2"));
|
|
ASSERT_FALSE(reader.hasRecord("key2000"));
|
|
at::DataPtr data_ptr;
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
int64_t size;
|
|
std::tie(data_ptr, size) = reader.getRecord("key1");
|
|
size_t off1 = reader.getRecordOffset("key1");
|
|
ASSERT_EQ(size, data1.size());
|
|
ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
|
|
ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
|
|
ASSERT_EQ(off1 % kFieldAlignment, 0);
|
|
|
|
std::tie(data_ptr, size) = reader.getRecord("key2");
|
|
size_t off2 = reader.getRecordOffset("key2");
|
|
ASSERT_EQ(off2 % kFieldAlignment, 0);
|
|
|
|
ASSERT_EQ(size, data2.size());
|
|
ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
|
|
ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
|
|
std::ostringstream oss;
|
|
// write records through writers
|
|
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
|
oss.write(static_cast<const char*>(b), n);
|
|
return oss ? n : 0;
|
|
});
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
|
std::array<char, 127> data1;
|
|
|
|
for (int i = 0; i < data1.size(); ++i) {
|
|
data1[i] = data1.size() - i;
|
|
}
|
|
writer.writeRecord("key1", data1.data(), data1.size());
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
|
|
std::array<char, 64> data2;
|
|
for (int i = 0; i < data2.size(); ++i) {
|
|
data2[i] = data2.size() - i;
|
|
}
|
|
writer.writeRecord("key2", data2.data(), data2.size());
|
|
|
|
const std::unordered_set<std::string>& written_records =
|
|
writer.getAllWrittenRecords();
|
|
ASSERT_EQ(written_records.size(), 2);
|
|
ASSERT_EQ(written_records.count("key1"), 1);
|
|
ASSERT_EQ(written_records.count("key2"), 1);
|
|
|
|
writer.writeEndOfFile();
|
|
|
|
std::string the_file = oss.str();
|
|
std::ofstream foo("output2.zip");
|
|
foo.write(the_file.c_str(), the_file.size());
|
|
foo.close();
|
|
|
|
std::istringstream iss(the_file);
|
|
|
|
// read records through readers
|
|
PyTorchStreamReader reader(&iss);
|
|
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
|
|
EXPECT_THROW(reader.getRecord("key3"), c10::Error);
|
|
|
|
// Reader should still work after throwing
|
|
EXPECT_TRUE(reader.hasRecord("key1"));
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace serialize
|
|
} // namespace caffe2
|