mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: After consulting with Owen, who pointed out the existence of the miniz library, I decided to take one last shot at using zip as our container format. miniz makes this surprisingly feasible and I think the benefits of using zip are large enough that we should do it. This replaces our custom container format with a zip archive, preserving all of the desirable features of our custom format, such as append-oriented writing and mmap'able tensor data, while adding a bunch of debugging advantages: 1. You can unzip and explore the container to debug what is going on with a model. 2. You can edit the model using a text editor (e.g. change the definition of a method, or edit the json-serialized meta-data), re-zip the file using OSX's native 'Compress' option, and re-load the result into pytorch. Note: this enables you to, e.g., print-debug serialized models. 3. We can easily enable features like compression in the future. 4. Stock python, without pytorch installed, and other programming languages can reasonably consume this format, using json and zipfile packages, which enables people to build tools like visualizers without those visualizers depending on pytorch. This will be especially useful if you want to, for instance, write a visualizer in javascript. Notes: * This adds miniz (https://github.com/richgel999/miniz) as a dependency. miniz is a self-contained library for reading/writing zipfiles that, unlike other zip libraries, also includes libz-compatible compress/decompress support. It is a single header and a single C file without any other dependencies. Note that the instructions for miniz explicitly state: > Please use the files from the releases page in your projects. Do not use the git checkout directly! So we have checked in the 'release' source. Miniz supports zip64, and its API is amenable to doing zip-align style things to align data. * Removes 'size' from RecordRef. This allows you to edit files in the zip archive without editing the meta-data file.
Very important if you want to print-debug serialized models. * PyTorchStreamReader/PyTorchStreamWriter keep mostly the same API (though keys become strings). However, their implementation is completely swapped out to use miniz. * Code exists to check for the old magic number to give a decent warning to our preview users after we change the format. * Container version information is now put in a stand-alone 'version' file in the archive and serves a similar purpose to the other container version info. * All files in the zip archive start at 64-byte boundaries, using an approach similar to zip-align. Tests check that this property remains true. While the writer does this, the reader doesn't depend on it, allowing user-created archives that can use compression and do not have to align data. * Added a test to check for > 4GB files and archives. Disabled by default because it takes almost 2 minutes to run. * torchscript files are now optional: if a submodule does not have methods, it will not be written. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14521 Reviewed By: jamesr66a Differential Revision: D13252945 Pulled By: zdevito fbshipit-source-id: 01209294c0f6543d0fd716f85a38532249c52f8c
305 lines
9.2 KiB
C++
305 lines
9.2 KiB
C++
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <exception>
#include <fstream>
#include <istream>
#include <ostream>

#include <c10/core/Allocator.h>
#include <ATen/core/Backend.h>

#include "caffe2/core/logging.h"
#include "caffe2/serialize/inline_container.h"

#include "miniz.h"
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
|
|
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
|
|
return self->read(file_ofs, static_cast<char*>(pBuf), n);
|
|
}
|
|
|
|
// Returns the final path component of `name` with its last extension removed.
// Both '/' and '\\' are treated as directory separators. Only a '.' that
// appears after the last separator counts as an extension marker.
// Returns "" when `name` is empty or ends with a separator.
static std::string basename(const std::string& name) {
  const size_t last_separator = name.find_last_of("\\/");
  const size_t first =
      (last_separator == std::string::npos) ? 0 : last_separator + 1;
  if (first >= name.size()) {
    return "";
  }

  // A dot located inside the directory part must not truncate the file name.
  const size_t last_dot = name.rfind('.');
  const bool has_extension =
      last_dot != std::string::npos && last_dot >= first;
  const size_t last = has_extension ? last_dot : name.size();
  return name.substr(first, last - first);
}
|
|
|
|
// Reads exactly `n` bytes starting at absolute stream offset `pos` into `buf`.
// Returns `n` on success. Returns 0 if either the seek or the read leaves the
// stream in a failed state (a short read is reported as 0, not as a partial
// count).
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  in_->seekg(pos);
  if (in_->fail()) {
    return 0;
  }
  in_->read(buf, n);
  return in_->fail() ? 0 : n;
}
|
|
|
|
// Opens a zip-based PyTorch archive for reading.
//
// file_name: path of the archive on disk; only opened when `in` is null.
// in:        optional caller-provided stream holding the archive. When it is
//            null the reader opens `file_name` through its own file stream.
//
// The constructor rejects files that start with the old "PYTORCH1" magic
// number, initializes miniz on top of the stream, derives the archive's
// top-level folder name from the first entry, and checks that the "version"
// record lies within [kMinSupportedFileFormatVersion,
// kMaxSupportedFileFormatVersion].
//
// Throws (through CAFFE_THROW / AT_ASSERTM) on I/O errors, malformed
// archives, or unsupported versions.
PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
    : ar_(new mz_zip_archive), in_(in) {
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  if (!in_) {
    file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
    in_ = &file_stream_;
    valid("opening archive");
  }

  in_->seekg(0, in_->end);
  size_t size = in_->tellg();

  // check for the old magic number,
  constexpr size_t kMagicValueLength = 8;
  if (size > kMagicValueLength) {
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number");
    AT_ASSERTM(
        memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release.");
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;

  // A constructor that throws never has its destructor run, so any exception
  // raised after miniz has been initialized would leak the reader state that
  // ~PyTorchStreamReader normally releases. Release it here and rethrow.
  // Calling mz_zip_reader_end on an archive whose initialization failed is
  // harmless: it just reports failure.
  try {
    mz_zip_reader_init(ar_.get(), size, 0);
    valid("reading zip archive");

    // figure out the archive_name (i.e. the zip folder all the other files are in)
    // all lookups to getRecord will be prefixed by this folder
    int n = mz_zip_reader_get_num_files(ar_.get());
    if (n == 0) {
      CAFFE_THROW("archive does not contain any files");
    }
    // First call queries the required buffer size, second call fills it.
    size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
    valid("getting filename");
    std::string buf(name_size, '\0');
    mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
    valid("getting filename");
    auto pos = buf.find_first_of('/');
    if (pos == std::string::npos) {
      CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
    }
    archive_name_ = buf.substr(0, pos);

    // version check
    at::DataPtr version_ptr;
    size_t version_size;
    std::tie(version_ptr, version_size) = getRecord("version");
    std::string version(static_cast<const char*>(version_ptr.get()), version_size);
    size_t version_number = caffe2::stoull(version);
    AT_ASSERTM(
        version_number >= kMinSupportedFileFormatVersion,
        "Attempted to read a PyTorch file with version ",
        c10::to_string(version_number),
        ", but the minimum supported version for reading is ",
        c10::to_string(kMinSupportedFileFormatVersion),
        ". Your PyTorch script module file is too old. Please re-export it again.");
    AT_ASSERTM(
        version_number <= kMaxSupportedFileFormatVersion,
        "Attempted to read a PyTorch file with version ",
        version_number,
        ", but the maximum supported version for reading is ",
        kMaxSupportedFileFormatVersion,
        ". Your PyTorch installation may be too old.");
  } catch (...) {
    mz_zip_reader_end(ar_.get());
    throw;
  }
}
|
|
|
|
void PyTorchStreamReader::valid(const char* what) {
|
|
auto err = mz_zip_get_last_error(ar_.get());
|
|
if (err != MZ_ZIP_NO_ERROR) {
|
|
CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
|
|
}
|
|
if (!*in_) {
|
|
CAFFE_THROW("PytorchStreamReader failed ", what, ".");
|
|
}
|
|
}
|
|
|
|
// Layout of the fixed-size part of a zip local file header, as used by the
// padding and record-offset computations in this file.

// Number of bytes in the fixed part of the local header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
// Byte offset, within the local header, of the 16-bit filename length.
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
// Byte offset, within the local header, of the 16-bit extra-field length.
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
|
|
|
|
static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
|
|
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
|
|
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
|
|
start += sizeof(mz_uint16) * 2;
|
|
if (size >= MZ_UINT32_MAX) {
|
|
start += 2*sizeof(mz_uint64);
|
|
}
|
|
if (cursor >= MZ_UINT32_MAX) {
|
|
start += sizeof(mz_uint64);
|
|
}
|
|
}
|
|
size_t mod = start % kFieldAlignment;
|
|
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
|
|
size_t padding_size = next_offset - start;
|
|
std::string buf(padding_size + 4, 'Z');
|
|
// zip extra encoding (key, size_of_extra_bytes)
|
|
buf[0] = 'F';
|
|
buf[1] = 'B';
|
|
buf[2] = (uint8_t) padding_size;
|
|
buf[3] = (uint8_t) (padding_size >> 8);
|
|
return buf;
|
|
}
|
|
|
|
size_t PyTorchStreamReader::getFileID(const std::string& name) {
|
|
std::stringstream ss;
|
|
ss << archive_name_ << "/" << name;
|
|
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
|
|
if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
|
|
CAFFE_THROW("file not found: ", ss.str());
|
|
}
|
|
valid("locating file");
|
|
return result;
|
|
}
|
|
|
|
// return dataptr, size
|
|
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
|
|
size_t key = getFileID(name);
|
|
mz_zip_archive_file_stat stat;
|
|
mz_zip_reader_file_stat(ar_.get(), key, &stat);
|
|
valid("retrieving file meta-data");
|
|
void * ptr = malloc(stat.m_uncomp_size);
|
|
mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
|
|
valid("reading file");
|
|
|
|
at::DataPtr retval(ptr, ptr, free, at::kCPU);
|
|
return std::make_tuple(std::move(retval), stat.m_uncomp_size);
|
|
}
|
|
|
|
// Decodes an unsigned 16-bit little-endian value from buf[0..1].
static int64_t read_le_16(uint8_t* buf) {
  const int low_byte = buf[0];
  const int high_byte = buf[1];
  return low_byte + (high_byte << 8);
}
|
|
|
|
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
|
|
mz_zip_archive_file_stat stat;
|
|
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
|
|
valid("retriving file meta-data");
|
|
in_->seekg(stat.m_local_header_ofs);
|
|
valid("seeking to file header");
|
|
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
|
|
in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
|
|
valid("reading file header");
|
|
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
|
|
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
|
|
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
|
|
}
|
|
|
|
|
|
// Releases the miniz reader state.
//
// Destructors are implicitly noexcept, so letting valid() throw here would
// call std::terminate (for example when the stream is still in a failed state
// after an earlier read error that the caller already handled). Failures
// while closing are therefore reported on stderr instead of being propagated.
PyTorchStreamReader::~PyTorchStreamReader() {
  mz_zip_reader_end(ar_.get());
  try {
    valid("closing reader");
  } catch (const std::exception& e) {
    std::fprintf(stderr, "%s\n", e.what());
  } catch (...) {
    std::fprintf(stderr, "PytorchStreamReader failed closing reader.\n");
  }
}
|
|
|
|
size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
|
|
auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
|
|
if (self->current_pos_ != file_ofs) {
|
|
// xxx - windows ostringstream refuses to seek to the end of an empty string
|
|
// so we workaround this by not calling seek unless necessary
|
|
// in the case of the first write (to the empty string) file_ofs and
|
|
// current_pos_ will be 0 and the seek won't occur.
|
|
self->out_->seekp(file_ofs);
|
|
if(!*self->out_)
|
|
return 0;
|
|
}
|
|
|
|
self->out_->write(static_cast<const char*>(pBuf), n);
|
|
if(!*self->out_)
|
|
return 0;
|
|
self->current_pos_ = file_ofs + n;
|
|
return n;
|
|
}
|
|
|
|
// Creates a zip-based PyTorch archive for writing.
//
// file_name: destination path. Its basename (directory and last extension
//            stripped) becomes the folder name that prefixes every record.
// out:       optional caller-provided stream; when null, `file_name` is
//            opened (truncating) through the writer's own file stream.
//
// A "version" record holding kMaxSupportedFileFormatVersion followed by a
// newline is written immediately. Throws if the name yields an empty archive
// name, or if opening / initializing / writing fails.
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
    : ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  if (archive_name_.size() == 0) {
    CAFFE_THROW("invalid file name: ", file_name);
  }

  if (out_ == nullptr) {
    const auto open_mode =
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary;
    file_stream_.open(file_name, open_mode);
    out_ = &file_stream_;
    valid("opening archive");
  }

  // Route all of miniz's output through ostream_write_func.
  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive");

  std::stringstream version;
  version << kMaxSupportedFileFormatVersion << "\n";
  const std::string version_text = version.str();
  writeRecord("version", version_text.c_str(), version_text.size());
}
|
|
|
|
void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size) {
|
|
AT_ASSERT(!finalized_);
|
|
std::stringstream ss;
|
|
ss << archive_name_ << "/" << name;
|
|
const std::string& full_name = ss.str();
|
|
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
|
uint32_t flags = 0;
|
|
mz_zip_writer_add_mem_ex_v2(
|
|
ar_.get(),
|
|
full_name.c_str(),
|
|
data,
|
|
size,
|
|
nullptr,
|
|
0,
|
|
flags,
|
|
0,
|
|
0,
|
|
nullptr,
|
|
padding.c_str(),
|
|
padding.size(),
|
|
nullptr,
|
|
0);
|
|
valid("writing file");
|
|
}
|
|
|
|
void PyTorchStreamWriter::writeEndOfFile() {
|
|
AT_ASSERT(!finalized_);
|
|
finalized_ = true;
|
|
mz_zip_writer_finalize_archive(ar_.get());
|
|
mz_zip_writer_end(ar_.get());
|
|
valid("writing central directory");
|
|
if (file_stream_.is_open())
|
|
file_stream_.close();
|
|
}
|
|
|
|
|
|
void PyTorchStreamWriter::valid(const char* what) {
|
|
auto err = mz_zip_get_last_error(ar_.get());
|
|
if (err != MZ_ZIP_NO_ERROR) {
|
|
CAFFE_THROW("PytorchStreamWriter failed ", what, ": ", mz_zip_get_error_string(err));
|
|
}
|
|
if (!*out_) {
|
|
CAFFE_THROW("PytorchStreamWriter failed ", what, ".");
|
|
}
|
|
}
|
|
|
|
// Finalizes the archive on destruction unless writeEndOfFile() has already
// been called.
PyTorchStreamWriter::~PyTorchStreamWriter() {
  if (finalized_) {
    return;
  }
  writeEndOfFile();
}
|
|
|
|
}} // namespace torch::jit
|