mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28263 When looking at profiles of deserializing small data from torch::load(), we found some straightforward string-related changes that in aggregate improve the base time by 25%. One of the main problems was over-use of std::stringstream - the constructors alone were 18%+ of the time spent. This change improves unpickling/deserializing by converting a handful of the hottest use cases from the profiles: - unpickler's readString() goes from 10.3% of time to mostly out of the picture - QualifiedName constructor (particularly Join call) was 8.9% of time, but afterwards disappears from the profiles. - getRecordID/hasRecord were ~5% each, but also get somewhat smaller. ghstack-source-id: 92158727 Test Plan: Benchmark in buck build mode/opt experimental/jeremyl/c2:SerializationBench Correctness in buck test mode/dev-nosan caffe2/test/... Differential Revision: D17997056 fbshipit-source-id: fc6d6c7da7557ff23c8e8c7dbe4c060abf860018
260 lines
8.9 KiB
C++
260 lines
8.9 KiB
C++
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_export_helpers.h>
#ifndef C10_MOBILE
#include <torch/csrc/jit/import_legacy.h>
#endif
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/unpickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>

#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

#include <ATen/ATen.h>

#include <algorithm>
#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using caffe2::serialize::FileAdapter;
|
|
using caffe2::serialize::IStreamAdapter;
|
|
using caffe2::serialize::PyTorchStreamReader;
|
|
using caffe2::serialize::ReadAdapterInterface;
|
|
|
|
void postSetStateValidate(const IValue& v) {
|
|
auto obj = v.toObject();
|
|
const auto& objType = obj->type();
|
|
for (size_t i = 0; i < objType->numAttributes(); i++) {
|
|
const auto& attrType = objType->getAttribute(i);
|
|
const auto& attrName = objType->getAttributeName(i);
|
|
const auto& slot = obj->getSlot(i);
|
|
// const auto attrType = objType->getAttribute(i);
|
|
// Verify that all the non-optional attributes have been initialized
|
|
// TODO: Issue #20497
|
|
if (attrType->kind() != TypeKind::OptionalType) {
|
|
TORCH_CHECK(
|
|
!slot.isNone(),
|
|
"The field '",
|
|
attrName,
|
|
"' was left unitialized after __setstate__, but expected a ",
|
|
"value of type '",
|
|
attrType->python_str(),
|
|
"'");
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
|
|
// This is a deserializer class which loads script modules from pt files.
|
|
// Content of the file is written using PyTorchStreamWriter, for details please
|
|
// check caffe2/serialize/inline_container.h.
|
|
// The module is saved in pickle. readArchive() is called to parse and construct
|
|
// the constant table and the script module.
|
|
class ScriptModuleDeserializer final {
|
|
public:
|
|
ScriptModuleDeserializer(
|
|
std::shared_ptr<script::CompilationUnit> cu,
|
|
std::unique_ptr<PyTorchStreamReader> reader)
|
|
: compilation_unit_(cu),
|
|
reader_(std::move(reader)),
|
|
source_importer_(
|
|
compilation_unit_,
|
|
&constants_table_,
|
|
[this](const std::string& qualifier) {
|
|
return findSourceInArchiveFromQualifier(
|
|
*reader_, export_prefix_, qualifier);
|
|
},
|
|
reader_->version()) {}
|
|
|
|
script::Module deserialize(
|
|
c10::optional<at::Device> device,
|
|
script::ExtraFilesMap& extra_files);
|
|
|
|
private:
|
|
IValue readArchive(const std::string& archive_name);
|
|
|
|
std::shared_ptr<script::CompilationUnit> compilation_unit_;
|
|
std::unique_ptr<PyTorchStreamReader> reader_;
|
|
c10::optional<at::Device> device_;
|
|
std::vector<at::Tensor> constants_table_;
|
|
script::SourceImporter source_importer_;
|
|
std::string export_prefix_ = "code/";
|
|
};
|
|
|
|
// Parse the pickle archive `<archive_name>.pkl` into an IValue. Raw tensor
// data referenced by the pickle is stored in sibling records named
// "<archive_name>/<id>".
IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
  std::string picklename = archive_name + ".pkl";
  at::DataPtr pickle_ptr;
  size_t pickle_size;
  std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename);

  size_t bytes_read = 0;
  auto data = reinterpret_cast<const char*>(pickle_ptr.get());
  // Reader callback for the unpickler: copy up to `len` bytes of the pickle
  // stream into `buffer`, returning the number of bytes copied (0 at EOF).
  auto reader = [&](char* buffer, size_t len) -> size_t {
    if (bytes_read >= pickle_size) {
      return 0;
    }
    len = std::min(pickle_size - bytes_read, len);
    // Copy len bytes into buffer
    const char* start = data + bytes_read;
    std::memcpy(buffer, start, len);
    bytes_read += len;
    return len;
  };

  // Resolve a qualified class name to a ClassType, compiling its source from
  // the archive if it has not been loaded yet.
  auto class_resolver = [&](const c10::QualifiedName& qn) {
    auto cls = source_importer_.loadNamedType(qn)->expect<ClassType>();
    return c10::StrongTypePtr(compilation_unit_, std::move(cls));
  };

  // Decouple how to get obj from type. In this file it's dependent on
  // Method.run() and graph executor, etc.
  // For bytecode import we need to decouple these dependencies.
  auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
    auto cls = type.type_->expect<at::ClassType>();
    size_t n = cls->numAttributes();
    if (checkHasValidSetGetState(type.type_)) {
      auto obj = c10::ivalue::Object::create(type, n);
      // XXX: Do not optimize __setstate__, so that we don't try to
      // specialize the class before it is initialized. An RAII guard is used
      // so optimization is re-enabled even if __setstate__ throws.
      struct OptimizeDisabledGuard {
        OptimizeDisabledGuard() {
          setGraphExecutorOptimize(false);
        }
        ~OptimizeDisabledGuard() {
          setGraphExecutorOptimize(true);
        }
      } no_optimize_guard;
      Function* set_state = type.type_->getMethod("__setstate__");
      // since we are in the middle of unpickling we might still have lists and
      // dicts that do not have accurate tags (e.g. they report they are
      // List[Any]). But we need to run __setstate__ which will check the input
      // type and may access the tags. Since setstate has a known input type, we
      // can correctly restore the tags now by apply the input type of set_state
      // to the state object being passed.
      restoreAccurateTypeTags(
          input, set_state->getSchema().arguments().at(1).type());
      (*set_state)({obj, input});
      postSetStateValidate(obj);
      return obj;
    } else {
      // No __setstate__: the state is a dict of attribute name -> value;
      // restore each declared attribute slot directly.
      auto dict = std::move(input).toGenericDict();
      auto obj = c10::ivalue::Object::create(type, n);
      for (size_t i = 0; i < n; ++i) {
        obj->setSlot(i, dict.at(cls->getAttributeName(i)));
      }
      return obj;
    }
  };

  // Record lookup for raw data (e.g. tensor storages), stored under
  // "<archive_name>/<name>".
  std::string archive_name_plus_slash = archive_name + "/";
  auto read_record = [&](const std::string& name) {
    std::string ss = archive_name_plus_slash + name;
    return std::get<0>(reader_->getRecord(ss));
  };

  Unpickler unpickler(
      reader, std::move(class_resolver), std::move(obj_loader),
      std::move(read_record), device_);
  return unpickler.parse_ivalue();
}
|
|
|
|
// Top-level entry point: load extra files, dispatch legacy-format models,
// populate the constants table, then materialize the module object.
script::Module ScriptModuleDeserializer::deserialize(
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  C10_LOG_API_USAGE_ONCE("torch.script.load");
  device_ = device;
  // Load extra files. Only keys the caller pre-populated are looked up;
  // assigning to an existing key does not invalidate the iteration.
  for (const auto& kv : extra_files) {
    const std::string key = "extra/" + kv.first;
    if (reader_->hasRecord(key)) {
      at::DataPtr meta_ptr;
      size_t meta_size;
      std::tie(meta_ptr, meta_size) = reader_->getRecord(key);
      extra_files[kv.first] =
          std::string(static_cast<char*>(meta_ptr.get()), meta_size);
    }
  }
  // "model.json" marks the legacy (pre-pickle) serialization format.
  if (reader_->hasRecord("model.json")) {
#ifndef C10_MOBILE
    return torch::jit::LEGACY_deserialize(
        compilation_unit_, std::move(reader_), device_);
#else
    AT_ERROR("Legacy model format is not supported on mobile.");
#endif
  }
  // Populate the constants table before reading "data", since the module's
  // code references these tensors by index. Iterate by const reference to
  // avoid copying each IValue.
  auto tuple = readArchive("constants").toTuple();
  for (const auto& constant : tuple->elements()) {
    constants_table_.push_back(constant.toTensor());
  }
  return script::Module(readArchive("data").toObject());
}
|
|
|
|
} // namespace
|
|
|
|
// Import a serialized script module from an input stream into `cu`.
script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  // Wrap the stream in a zip reader and hand everything to the deserializer.
  ScriptModuleDeserializer deserializer(
      std::move(cu), torch::make_unique<PyTorchStreamReader>(&in));
  return deserializer.deserialize(device, extra_files);
}
|
|
|
|
// Import a serialized script module from a file path into `cu`.
script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  // Open the file through a zip reader and hand everything to the
  // deserializer.
  ScriptModuleDeserializer deserializer(
      std::move(cu), torch::make_unique<PyTorchStreamReader>(filename));
  return deserializer.deserialize(device, extra_files);
}
|
|
|
|
// Import a serialized script module through a caller-supplied read adapter.
script::Module import_ir_module(
    std::shared_ptr<script::CompilationUnit> cu,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  // Wrap the adapter in a zip reader and hand everything to the deserializer.
  ScriptModuleDeserializer deserializer(
      std::move(cu),
      torch::make_unique<PyTorchStreamReader>(std::move(rai)));
  return deserializer.deserialize(device, extra_files);
}
|
|
|
|
// Load a script module from an input stream, delegating to the
// ReadAdapterInterface-based overload.
script::Module load(
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  // unique_ptr<IStreamAdapter> converts implicitly to the base-class pointer
  // the generic overload expects.
  return load(
      caffe2::make_unique<IStreamAdapter>(&in), device, extra_files);
}
|
|
|
|
// Load a script module from a file path, delegating to the
// ReadAdapterInterface-based overload.
script::Module load(
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  // unique_ptr<FileAdapter> converts implicitly to the base-class pointer
  // the generic overload expects.
  return load(
      caffe2::make_unique<FileAdapter>(filename), device, extra_files);
}
|
|
|
|
// Load a script module through an arbitrary read adapter. This is the
// overload the stream/file variants funnel into.
script::Module load(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    script::ExtraFilesMap& extra_files) {
  // Every top-level load gets a fresh compilation unit to own the classes
  // and functions defined by the loaded module.
  ScriptModuleDeserializer deserializer(
      std::make_shared<script::CompilationUnit>(),
      torch::make_unique<PyTorchStreamReader>(std::move(rai)));
  return deserializer.deserialize(device, extra_files);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|