2018-07-28 05:21:05 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <cstdio>
|
|
|
|
|
#include <cstring>
|
|
|
|
|
#include <cerrno>
|
2018-09-28 14:41:26 +00:00
|
|
|
#include <istream>
|
|
|
|
|
#include <ostream>
|
|
|
|
|
#include <fstream>
|
2018-07-28 05:21:05 +00:00
|
|
|
|
2018-11-27 20:43:24 +00:00
|
|
|
#include <c10/core/Allocator.h>
|
2018-10-26 19:04:57 +00:00
|
|
|
#include <ATen/core/Backend.h>
|
|
|
|
|
|
|
|
|
|
#include "caffe2/core/logging.h"
|
|
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
namespace torch { namespace jit {
|
|
|
|
|
|
|
|
|
|
// This file defines an on-disk serialization format to be used for PyTorch
|
|
|
|
|
// model serialization. All integer values are serialized as little-endian.
|
|
|
|
|
// Everything in this format is aligned to 64-byte boundaries to allow for direct
|
|
|
|
|
// memory mapping and use in, for example, AVX512 instructions.
|
|
|
|
|
// The format is as follows:
|
|
|
|
|
//
|
|
|
|
|
// -- File header --
|
|
|
|
|
// [8 bytes] Magic number - little endian integer that spells 'PYTORCH1' in ASCII
|
|
|
|
|
// [8 bytes] Version number - The version of this file format that this file is in.
|
|
|
|
|
// this allows us to revise and extend this format
|
2018-10-18 16:41:30 +00:00
|
|
|
// [48 bytes] Padding/reserved
|
2018-07-28 05:21:05 +00:00
|
|
|
//
|
|
|
|
|
// After the file header reside N records of the format
|
|
|
|
|
// [8 bytes] Tag - this is a tag that identifies the type of this record. The
|
|
|
|
|
// values are defined in the RecordTags enum below.
|
|
|
|
|
// [8 bytes] size - Size in bytes of the payload of this record
|
2018-10-18 16:41:30 +00:00
|
|
|
// [48 bytes] Pad/reserved - This space pads out the payload to a 64-byte alignment.
|
2018-07-28 05:21:05 +00:00
|
|
|
// [size bytes] Payload - The actual raw data for the object serialized in this record
|
|
|
|
|
// [64 - (size % 64) bytes] Pad/reserved - pad out this record so the next
|
|
|
|
|
// one is aligned to 64 bytes
|
|
|
|
|
//
|
|
|
|
|
// Following those records is a special footer:
|
|
|
|
|
// [8 bytes] Tag - This tag field should contain the value for RecordTags::FOOTER
|
|
|
|
|
// to correctly identify the footer
|
|
|
|
|
// [8 bytes] Offset of last record - The last record in this format is used
|
|
|
|
|
// as an index into the rest of the file, so
|
|
|
|
|
// a reader can use this offset to seek to
|
|
|
|
|
// the last record and read the index.
|
2018-10-18 16:41:30 +00:00
|
|
|
// [48 bytes] Pad/reserved - Pad out the footer s.t. the whole file's size is a
|
2018-07-28 05:21:05 +00:00
|
|
|
// multiple of 64 bytes.
|
|
|
|
|
//
|
|
|
|
|
//
|
|
|
|
|
// When developing this format we want to pay particular attention to the
|
|
|
|
|
// following use cases:
|
|
|
|
|
//
|
|
|
|
|
// -- Reading --
|
|
|
|
|
// 1) Reading with full random access
|
|
|
|
|
// a) Reading with file api's such as fread()
|
|
|
|
|
// b) mmaping the file and jumping around the mapped region
|
|
|
|
|
// 2) Reading with 1-pass sequential access
|
|
|
|
|
// -> A reader will need to build up a data structure of parsed structures
|
|
|
|
|
// as it reads
|
|
|
|
|
//
|
|
|
|
|
// -- Writing --
|
|
|
|
|
// 1) Writing with full random access
|
|
|
|
|
// 2) Writing with 1-pass sequential access
|
|
|
|
|
// -> We must take care not to require updating values that have already
|
|
|
|
|
// been written. We place the variable-length index at the end and do
|
|
|
|
|
//    not put any indices into the header to fulfill this constraint.
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
// The serialized model, which contains all the metadata information,
|
|
|
|
|
// should be stored as the last record. One major reason is supporting
|
|
|
|
|
// the continuous writing. While writing to file, the index/offset of a tensor
|
|
|
|
|
// is unknown until we start dumping it. So we would like to put the model
|
|
|
|
|
// data (i.e. the header) in the end to allow hard coding the offsets inside
|
|
|
|
|
// the model metadata. Another reason is that the size of tensor data is
|
|
|
|
|
// usually stable. As long as the shape and type of the tensor do not change,
|
|
|
|
|
// the size of the data won't change. On the other side, the size of the
|
|
|
|
|
// serialized model is likely to change, so we store it as the last record, and
|
|
|
|
|
// we don't need to move previous records when updating the model data.
|
|
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
namespace {

// Type tag written as the first 8 bytes of every record header
// (see the format description above).
enum RecordTags {
  STORAGE = 1, // raw payload record (e.g. tensor storage bytes)
  FOOTER = 2,  // terminating footer record holding the last-record offset
};

// Common constants
// Little-endian integer spelling 'PYTORCH1' in ASCII; the first 8 bytes
// of every file in this format.
constexpr uint64_t kFileMagicNumber = 0x314843524f545950L; // PYTORCH1
constexpr uint64_t kFieldAlignment =
    64L; // 64 byte alignment supports up to AVX512 for mmap

// Reader-specific constants
// FileFormatVersion 1 was used in PyTorch 1.0 rc, which is a hacked ONNX proto.
constexpr uint64_t kMinSupportedFileFormatVersion = 0x2L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2L;

// Writer-specific constants
// Version stamped into the header of newly written files.
constexpr uint64_t kFileFormatVersion = 0x2L;
// Byte value used to fill alignment padding; recognizable (0xEF) when
// inspecting a file in a hex dump.
constexpr char kPadValue = -17; // 0xEF

} // namespace
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
class PyTorchStreamReader final {
|
2018-07-28 05:21:05 +00:00
|
|
|
public:
|
2018-10-26 19:04:57 +00:00
|
|
|
PyTorchStreamReader(std::istream* in) : in_(in) {
|
2018-07-28 05:21:05 +00:00
|
|
|
// Store file size so we know when we're done reading because the f* APIs
|
|
|
|
|
// don't do a good job of that
|
2018-10-26 19:04:57 +00:00
|
|
|
in_->seekg(0L, in_->end);
|
|
|
|
|
file_size_ = in_->tellg();
|
|
|
|
|
readAndValidateFileFooter();
|
2018-07-28 05:21:05 +00:00
|
|
|
// Do this now since we're reasonably sure this is actually a PyT file from
|
|
|
|
|
// the header.
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
file_size_ % kFieldAlignment == 0,
|
|
|
|
|
"File length is not a multiple of the alignment"
|
2018-11-14 08:19:08 +00:00
|
|
|
" size. Is this a valid PyTorch model file? File size: ",
|
2018-11-16 20:06:21 +00:00
|
|
|
c10::to_string(file_size_));
|
2018-10-26 19:04:57 +00:00
|
|
|
readAndValidateFileHeader();
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-08-28 02:22:41 +00:00
|
|
|
std::tuple<at::DataPtr, size_t> getLastRecord() {
|
2018-10-26 19:04:57 +00:00
|
|
|
return getRecordWithKey(last_record_offset_);
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
|
|
|
|
// return dataptr, size
|
2018-08-28 02:22:41 +00:00
|
|
|
std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
|
2018-07-28 05:21:05 +00:00
|
|
|
// Seek to the provided offset
|
2018-10-26 19:04:57 +00:00
|
|
|
cursor_ = key;
|
|
|
|
|
in_->seekg(cursor_);
|
|
|
|
|
|
|
|
|
|
at::DataPtr retval;
|
|
|
|
|
size_t size;
|
|
|
|
|
size_t retkey;
|
|
|
|
|
std::tie(retval, retkey, size) = getNextRecord();
|
|
|
|
|
AT_ASSERT(key == retkey);
|
|
|
|
|
return std::tuple<at::DataPtr, size_t>(std::move(retval), size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return dataptr, key, size
|
|
|
|
|
std::tuple<at::DataPtr, size_t, size_t> getNextRecord() {
|
|
|
|
|
size_t key = cursor_;
|
|
|
|
|
AT_ASSERTM(hasNextRecord(), "No more record, but hasNextRecord is called.");
|
|
|
|
|
AT_ASSERTM(
|
|
|
|
|
key % kFieldAlignment == 0,
|
|
|
|
|
"Provided key is not divisible by the alignment size.");
|
2018-07-28 05:21:05 +00:00
|
|
|
auto tag = read64BitIntegerLittleEndian();
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
tag == RecordTags::STORAGE,
|
|
|
|
|
"Attempted to read a record of non-storage type");
|
2018-07-28 05:21:05 +00:00
|
|
|
auto size = read64BitIntegerLittleEndian();
|
|
|
|
|
seekToNextAlignmentBoundary();
|
2018-10-26 19:04:57 +00:00
|
|
|
auto* ptr = malloc(size);
|
2018-08-28 02:22:41 +00:00
|
|
|
at::DataPtr retval(ptr, ptr, free, at::kCPU);
|
2018-09-28 14:41:26 +00:00
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
in_->read(static_cast<char*>(ptr), size);
|
|
|
|
|
cursor_ += size;
|
2018-07-28 05:21:05 +00:00
|
|
|
seekToNextAlignmentBoundary();
|
2018-10-26 19:04:57 +00:00
|
|
|
return std::tuple<at::DataPtr, size_t, size_t>(
|
|
|
|
|
std::move(retval), key, size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool hasNextRecord() const {
|
|
|
|
|
// if this is not the last record, at least we have
|
|
|
|
|
// another record header (kFieldAlignment) and
|
|
|
|
|
// the footer (kFieldAlignment)
|
|
|
|
|
return cursor_ + kFieldAlignment * 2 <= file_size_;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-09-28 14:41:26 +00:00
|
|
|
~PyTorchStreamReader() {
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
private:
|
2018-10-26 19:04:57 +00:00
|
|
|
std::istream* in_;
|
|
|
|
|
size_t cursor_ = 0;
|
|
|
|
|
size_t file_size_;
|
|
|
|
|
size_t last_record_offset_;
|
2018-07-28 05:21:05 +00:00
|
|
|
|
|
|
|
|
// Utility functions
|
|
|
|
|
uint64_t read64BitIntegerLittleEndian() {
|
2018-10-26 19:04:57 +00:00
|
|
|
uint64_t retval;
|
|
|
|
|
// TODO endian swap on platforms that need it?
|
|
|
|
|
in_->read(reinterpret_cast<char*>(&retval), 8);
|
|
|
|
|
std::streamsize read_bytes = in_->gcount();
|
|
|
|
|
AT_ASSERTM(
|
|
|
|
|
read_bytes == 8,
|
2018-11-13 20:28:24 +00:00
|
|
|
"Expected to read 8 bytes but got ", read_bytes, " bytes");
|
2018-10-26 19:04:57 +00:00
|
|
|
cursor_ += read_bytes;
|
|
|
|
|
return retval;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void seekToNextAlignmentBoundary() {
|
2018-10-26 19:04:57 +00:00
|
|
|
size_t next_offset =
|
|
|
|
|
(cursor_ + kFieldAlignment) - (cursor_ % kFieldAlignment);
|
|
|
|
|
size_t pad_amount = next_offset - cursor_;
|
|
|
|
|
cursor_ += pad_amount;
|
|
|
|
|
in_->seekg(cursor_);
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// File format deserialization functions
|
|
|
|
|
void readAndValidateFileHeader() {
|
2018-10-26 19:04:57 +00:00
|
|
|
// Validate magic number
|
|
|
|
|
cursor_ = 0;
|
|
|
|
|
in_->seekg(cursor_);
|
|
|
|
|
uint64_t magic = read64BitIntegerLittleEndian();
|
|
|
|
|
AT_ASSERTM(
|
|
|
|
|
magic == kFileMagicNumber,
|
|
|
|
|
"Magic number mismatch in PyTorch file. File may"
|
|
|
|
|
" be corrupted or is not actually a PyTorch file.");
|
|
|
|
|
// magic number mismatch in PyTorch file.
|
|
|
|
|
uint64_t file_format_version = read64BitIntegerLittleEndian();
|
2018-11-14 08:19:08 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
file_format_version >= kMinSupportedFileFormatVersion,
|
|
|
|
|
"Attempted to read a PyTorch file with version ",
|
2018-11-16 20:06:21 +00:00
|
|
|
c10::to_string(file_format_version),
|
2018-11-14 08:19:08 +00:00
|
|
|
", but the minimum supported version for reading is ",
|
2018-11-16 20:06:21 +00:00
|
|
|
c10::to_string(kMinSupportedFileFormatVersion),
|
2018-11-14 08:19:08 +00:00
|
|
|
". Your PyTorch script module file is too old. Please re-export it again.");
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
file_format_version <= kMaxSupportedFileFormatVersion,
|
2018-11-13 20:28:24 +00:00
|
|
|
"Attempted to read a PyTorch file with version ",
|
2018-10-26 19:04:57 +00:00
|
|
|
file_format_version,
|
2018-11-13 20:28:24 +00:00
|
|
|
", but the maximum supported version for reading is ",
|
|
|
|
|
kMaxSupportedFileFormatVersion,
|
|
|
|
|
". Your PyTorch installation may be too old.");
|
2018-10-26 19:04:57 +00:00
|
|
|
seekToNextAlignmentBoundary();
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
void readAndValidateFileFooter() {
|
|
|
|
|
// Seek to location of file footer. We've already validated that the file
|
|
|
|
|
// length is a multiple of the alignment size
|
2018-10-26 19:04:57 +00:00
|
|
|
cursor_ = file_size_ - kFieldAlignment;
|
|
|
|
|
in_->seekg(cursor_);
|
2018-07-28 05:21:05 +00:00
|
|
|
auto tag = read64BitIntegerLittleEndian();
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
tag == RecordTags::FOOTER,
|
|
|
|
|
"File footer has wrong record type. Is this file corrupted?");
|
|
|
|
|
last_record_offset_ = read64BitIntegerLittleEndian();
|
|
|
|
|
AT_ASSERTM(
|
|
|
|
|
last_record_offset_ < file_size_,
|
|
|
|
|
"Offset of last record is higher than the size"
|
|
|
|
|
" of the file! Is this file corrupted?");
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
class PyTorchStreamWriter final {
|
2018-07-28 05:21:05 +00:00
|
|
|
public:
|
2018-10-26 19:04:57 +00:00
|
|
|
PyTorchStreamWriter(std::ostream* out) : out_(out) {
|
2018-07-28 05:21:05 +00:00
|
|
|
writeFileHeader();
|
|
|
|
|
// In the case that we do not write any records into this file, the last
|
|
|
|
|
// record index written into the footer will point to the footer itself.
|
2018-10-26 19:04:57 +00:00
|
|
|
last_record_idx_ = cursor_;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
|
|
|
|
uint64_t writeRecord(const void* data, size_t size) {
|
|
|
|
|
AT_ASSERTM(!finalized_, "should not be finalized!");
|
|
|
|
|
uint64_t record_offset = cursor_;
|
|
|
|
|
last_record_idx_ = record_offset;
|
2018-07-28 05:21:05 +00:00
|
|
|
write64BitIntegerLittleEndian(RecordTags::STORAGE);
|
|
|
|
|
write64BitIntegerLittleEndian(size);
|
|
|
|
|
padToNextAlignmentBoundary();
|
|
|
|
|
writeBuffer(data, size);
|
|
|
|
|
padToNextAlignmentBoundary();
|
|
|
|
|
return record_offset;
|
|
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
void writeEndOfFile() {
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(!finalized_, "cannot finalize again!");
|
2018-07-28 05:21:05 +00:00
|
|
|
writeFileFooter();
|
2018-10-26 19:04:57 +00:00
|
|
|
finalized_ = true;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
|
|
|
|
int64_t getCurrentSize() const {
|
|
|
|
|
return static_cast<int64_t>(cursor_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool finalized() const {
|
|
|
|
|
return finalized_;
|
|
|
|
|
}
|
|
|
|
|
|
2018-09-28 14:41:26 +00:00
|
|
|
~PyTorchStreamWriter() {
|
2018-10-26 19:04:57 +00:00
|
|
|
if (!finalized_) {
|
2018-07-28 05:21:05 +00:00
|
|
|
writeEndOfFile();
|
|
|
|
|
}
|
|
|
|
|
}
|
2018-10-26 19:04:57 +00:00
|
|
|
|
2018-07-28 05:21:05 +00:00
|
|
|
private:
|
2018-10-26 19:04:57 +00:00
|
|
|
std::ostream* out_;
|
|
|
|
|
size_t cursor_ = 0;
|
|
|
|
|
bool finalized_ = false;
|
|
|
|
|
size_t last_record_idx_ = 0;
|
2018-07-28 05:21:05 +00:00
|
|
|
|
|
|
|
|
// Utility functions
|
|
|
|
|
void write64BitIntegerLittleEndian(const uint64_t value) {
|
|
|
|
|
// TODO endian swap on platforms that need it?
|
2018-10-26 19:04:57 +00:00
|
|
|
out_->write(reinterpret_cast<const char*>(&value), 8);
|
|
|
|
|
cursor_ += 8u;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void writePad(const size_t num_bytes) {
|
2018-10-26 19:04:57 +00:00
|
|
|
// TODO: move this buffer to the .cc file
|
|
|
|
|
static std::vector<char> pad_buffer_(kFieldAlignment, kPadValue);
|
|
|
|
|
out_->write(pad_buffer_.data(), num_bytes);
|
|
|
|
|
cursor_ += num_bytes;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void padToNextAlignmentBoundary() {
|
2018-10-26 19:04:57 +00:00
|
|
|
size_t next_offset =
|
|
|
|
|
(cursor_ + kFieldAlignment) - (cursor_ % kFieldAlignment);
|
|
|
|
|
size_t pad_amount = next_offset - cursor_;
|
2018-07-28 05:21:05 +00:00
|
|
|
writePad(pad_amount);
|
|
|
|
|
}
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
void writeBuffer(const void* data, size_t size) {
|
|
|
|
|
out_->write(static_cast<const char*>(data), size);
|
|
|
|
|
cursor_ += size;
|
2018-07-28 05:21:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// File format write functions
|
|
|
|
|
void writeFileHeader() {
|
|
|
|
|
write64BitIntegerLittleEndian(kFileMagicNumber);
|
|
|
|
|
write64BitIntegerLittleEndian(kFileFormatVersion);
|
|
|
|
|
padToNextAlignmentBoundary();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void writeFileFooter() {
|
|
|
|
|
write64BitIntegerLittleEndian(RecordTags::FOOTER);
|
2018-10-26 19:04:57 +00:00
|
|
|
write64BitIntegerLittleEndian(last_record_idx_);
|
2018-07-28 05:21:05 +00:00
|
|
|
padToNextAlignmentBoundary();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
class PyTorchFileReader final {
|
2018-09-28 14:41:26 +00:00
|
|
|
public:
|
2018-10-26 19:04:57 +00:00
|
|
|
PyTorchFileReader(const std::string& filename)
|
|
|
|
|
: in_(filename, std::ios_base::binary), stream_reader_(&in_) {}
|
|
|
|
|
|
|
|
|
|
bool hasNextRecord() const {
|
|
|
|
|
return stream_reader_.hasNextRecord();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return dataptr, key, size
|
|
|
|
|
std::tuple<at::DataPtr, int64_t, int64_t> getNextRecord() {
|
|
|
|
|
return stream_reader_.getNextRecord();
|
|
|
|
|
}
|
2018-09-28 14:41:26 +00:00
|
|
|
|
|
|
|
|
std::tuple<at::DataPtr, size_t> getLastRecord() {
|
2018-10-26 19:04:57 +00:00
|
|
|
return stream_reader_.getLastRecord();
|
2018-09-28 14:41:26 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
|
2018-10-26 19:04:57 +00:00
|
|
|
return stream_reader_.getRecordWithKey(key);
|
2018-09-28 14:41:26 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
2018-10-26 19:04:57 +00:00
|
|
|
std::ifstream in_;
|
|
|
|
|
PyTorchStreamReader stream_reader_;
|
2018-09-28 14:41:26 +00:00
|
|
|
};
|
|
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
class PyTorchFileWriter final {
|
2018-09-28 14:41:26 +00:00
|
|
|
public:
|
2018-10-26 19:04:57 +00:00
|
|
|
PyTorchFileWriter(const std::string& filename)
|
|
|
|
|
: out_(filename, std::ios_base::binary), stream_writer_(&out_) {}
|
2018-09-28 14:41:26 +00:00
|
|
|
|
2018-10-26 19:04:57 +00:00
|
|
|
uint64_t writeRecord(const void* data, size_t size) {
|
|
|
|
|
AT_ASSERTM(
|
|
|
|
|
!stream_writer_.finalized(),
|
|
|
|
|
"cannot write to a finalized stream writer.");
|
|
|
|
|
return stream_writer_.writeRecord(data, size);
|
2018-09-28 14:41:26 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void writeEndOfFile() {
|
2018-10-26 19:04:57 +00:00
|
|
|
AT_ASSERTM(
|
|
|
|
|
!stream_writer_.finalized(),
|
|
|
|
|
"cannot write end to a finalized stream writer.");
|
|
|
|
|
stream_writer_.writeEndOfFile();
|
|
|
|
|
out_.close();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t getCurrentSize() const {
|
|
|
|
|
return stream_writer_.getCurrentSize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool closed() const {
|
|
|
|
|
return stream_writer_.finalized();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~PyTorchFileWriter() {
|
|
|
|
|
if (!closed()) {
|
|
|
|
|
// make sure we finalize the steam_writer_ before out_
|
|
|
|
|
// is destroyed.
|
|
|
|
|
writeEndOfFile();
|
|
|
|
|
}
|
2018-09-28 14:41:26 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
2018-10-26 19:04:57 +00:00
|
|
|
std::ofstream out_;
|
|
|
|
|
PyTorchStreamWriter stream_writer_;
|
2018-09-28 14:41:26 +00:00
|
|
|
};
|
2018-07-28 05:21:05 +00:00
|
|
|
}} // namespace torch::jit
|