mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: BC NOTE: This change makes it so modules saved with torch.jit.save in PyTorch 1.6 can be loaded by previous versions of PyTorch unless they use torch.div or (soon) torch.full. It also lets tensors saved using torch.save be loaded by previous versions. So this is the opposite of BC-breaking, but I'm using that label to highlight this issue since we don't have a "BC-improving" label. PR NOTE: When an operator's semantics change in PyTorch we want to do two things: 1) Preserve the semantics of older serialized Torchscript programs that use the operator 2) Ensure the new semantics are respected Historically, this meant writing a Versioned Symbol that would remap older versions of the operator into current PyTorch code (1), and bumping the produced file format version (2). Unfortunately, bumping the produced file format version is a nuclear option for ensuring semantics are respected, since it also prevents older versions of PyTorch from loading anything (even tensors!) from newer versions. Dynamic versioning addresses the nuclear consequences of bumping the produced file format version by only bumping it when necessary. That is, when an operator with changed semantics is detected in the serialized Torchscript. This will prevent Torchscript programs that use the changed operator from loading on earlier versions of PyTorch, as desired, but will have no impact on programs that don't use the changed operator. Note that this change is only applicable when using torch.jit.save and torch.jit.load. torch.save pickles the given object using pickle (by default), which saves a function's Python source directly. No new tests for this behavior are added since the existing tests for versioned division in test_save_load already validate that models with div are loaded correctly at version 4. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40279 Reviewed By: dzhulgakov Differential Revision: D22168291 Pulled By: mruberry fbshipit-source-id: e71d6380e727e25123c7eedf6d80e5d7f1fe9f95
408 lines
14 KiB
C++
408 lines
14 KiB
C++
#include <torch/csrc/jit/serialization/export.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/ATen.h>

#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>

#include <string>
#include <utility>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
char const* toString(OpCode op);
|
|
|
|
namespace {
|
|
// Accessor for the process-wide extra-files hook. Returns a mutable
// reference so callers can install or clear the hook; it starts out unset.
ExportModuleExtraFilesHook& GetExtraFilesHook() {
  static ExportModuleExtraFilesHook hook_fn = nullptr;
  return hook_fn;
}
|
|
|
|
// Packs the given IValues into a single Tuple IValue.
static IValue Tup(std::vector<IValue> ivalues) {
  auto tuple_holder = c10::ivalue::Tuple::create(std::move(ivalues));
  return tuple_holder;
}
|
|
|
|
// Builds a "table" IValue: a tuple of (key, value) pair-tuples, preserving
// the order of the given entries.
static IValue Table(
    const std::vector<std::pair<std::string, IValue>>& entries) {
  std::vector<IValue> rows;
  rows.reserve(entries.size());
  for (const auto& entry : entries) {
    rows.push_back(Tup({entry.first, entry.second}));
  }
  return Tup(std::move(rows));
}
|
|
|
|
// Lowers a single Function into the tuple form stored in the bytecode
// archive: (qualified_name, table), where the table maps the keys
// "instructions", "operators", "constants", "types" and "register_size"
// to their serialized values.
c10::IValue getFunctionTuple(const Function& func) {
  // Work on a private copy of the graph so inlining does not mutate the
  // original function's graph.
  auto graph = func.graph()->copy();
  Inline(*graph);
  torch::jit::Code code(graph, func.name());

  // Mutable copy: CALL instructions may be rewritten in place below.
  auto instructions_copy = code.instructions();

  // operator names
  std::vector<c10::OperatorName> opnames;
  std::vector<std::string> method_names;
  for (size_t i = 0; i < instructions_copy.size(); ++i) {
    Instruction ins = instructions_copy[i];
    if (ins.op == OP || ins.op == OPN) {
      // Record the operator name of the node that produced this instruction.
      auto node = code.instructions_source()[i];
      opnames.emplace_back(node->schema().operator_name());
    }
    // CALL nodes at this point represent built-in (i.e. non-Graph)
    // functions that were not inlined. Here we convert the CALL
    // instructions for these functions into INTERFACE_CALL instructions
    // s.t. at runtime, we will look up the Function* on the Type of the
    // 0th argument in the stack and call that directly.
    if (ins.op == CALL) {
      auto node = code.instructions_source()[i];
      if (node->kind() == prim::CallMethod) {
        // NB: replacing instruction
        // The method name is appended to the constants below (after the
        // existing constant table), so its index is offset by the current
        // constant-table size.
        auto method_name_idx =
            code.constant_table().size() + method_names.size();
        method_names.emplace_back(node->s(attr::name));
        Instruction new_instr{INTERFACE_CALL,
                              static_cast<int32_t>(method_name_idx),
                              static_cast<uint16_t>(node->inputs().size())};
        instructions_copy[i] = std::move(new_instr);
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Unsupported node kind on CALL opcode for mobile");
      }
    }
  }

  // instructions
  // Each instruction is serialized as (opcode-name, X, N).
  std::vector<IValue> instructions;
  instructions.reserve(instructions_copy.size());
  for (Instruction ins : instructions_copy) {
    instructions.emplace_back(Tup({toString(ins.op), ins.X, ins.N}));
  }

  // operators
  // Each operator is serialized as (name, overload_name).
  std::vector<IValue> operators;
  operators.reserve(opnames.size());
  for (const auto& opname : opnames) {
    operators.emplace_back(Tup({opname.name, opname.overload_name}));
  }

  // constants
  //
  // Make a copy of the constants and append the method names
  // that we emitted for the converted INTERFACE_CALL nodes above.
  auto constants = code.constant_table();
  for (auto& method_name : method_names) {
    constants.emplace_back(std::move(method_name));
  }

  // types
  // Serialized as annotation strings (e.g. Python-style type annotations).
  std::vector<IValue> types;
  types.reserve(code.type_table().size());
  for (const TypePtr& t : code.type_table()) {
    types.emplace_back(t->annotation_str());
  }

  // since the register location is embedded into the bytecode, pass the
  // register size
  auto register_size = static_cast<int>(code.register_size());

  auto table = Table({{"instructions", Tup(instructions)},
                      {"operators", Tup(operators)},
                      {"constants", Tup(constants)},
                      {"types", Tup(types)},
                      {"register_size", register_size}});

  return Tup({func.qualname().qualifiedName(), table});
}
|
|
|
|
void setstateTuple(const IValue& ivalue, std::vector<c10::IValue>& elements) {
|
|
if (!ivalue.isObject())
|
|
return;
|
|
auto obj = ivalue.toObject();
|
|
auto type = obj->type();
|
|
if (checkHasValidSetGetState(type)) {
|
|
Function& setstate = type->getMethod("__setstate__");
|
|
if (setstate.isGraphFunction()) {
|
|
elements.push_back(getFunctionTuple(setstate));
|
|
}
|
|
} else {
|
|
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
|
|
setstateTuple(obj->getSlot(i), elements);
|
|
}
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
void moduleMethodsTuple(
|
|
const Module& module,
|
|
std::vector<c10::IValue>& elements) {
|
|
auto methods = module.get_methods();
|
|
// top level methods
|
|
for (const auto& method : methods) {
|
|
elements.push_back(getFunctionTuple(method.function()));
|
|
}
|
|
|
|
// __setstate__ of all components
|
|
setstateTuple(module._ivalue(), elements);
|
|
}
|
|
|
|
// Installs the process-wide hook that contributes additional "extra/" files
// when a module is exported. Passing a null hook clears it.
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
  // `hook` is a by-value sink parameter; move it into the global slot
  // instead of copying (ExportModuleExtraFilesHook is a std::function,
  // whose copy may allocate).
  GetExtraFilesHook() = std::move(hook);
}
|
|
|
|
// Serializes a Module into a zip archive (via PyTorchStreamWriter) holding
// the pickled model data, the printed source code, tensor constants, debug
// info, and optionally mobile bytecode.
class ScriptModuleSerializer {
 public:
  // Writes the archive to a file on disk.
  explicit ScriptModuleSerializer(const std::string& filename)
      : writer_(filename) {}

  // Writes the archive through a caller-supplied callback; the callback
  // returns the number of bytes it consumed.
  explicit ScriptModuleSerializer(
      const std::function<size_t(const void*, size_t)>& writer_func)
      : writer_(writer_func) {}

  // Entry point: serializes `module` (plus any `extra_files`) into the
  // writer. When `bytecode_format` is set, a "bytecode" archive for mobile
  // is emitted as well. The ordering of the write* calls below matters:
  // writing "data" populates the class dependencies and constant table that
  // the later archives consume.
  void serialize(
      const Module& module,
      const ExtraFilesMap& extra_files,
      bool bytecode_format) {
    C10_LOG_API_USAGE_ONCE("torch.script.save");
    writeExtraFiles(module, extra_files);
    // Serialize the model object
    writeArchive("data", module._ivalue());
    // Then we serialize all code info.
    writeCode(module.type());
    // The tensor constants from the code are written to a separate archive
    // so loading the code does not depend on loading the data
    std::vector<IValue> ivalue_constants(
        constant_table_.begin(), constant_table_.end());
    writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
    if (bytecode_format) {
      writeByteCode(module);
    }

    // Acquires and sets minimum (dynamic) version
    // Each printed file reports the minimum file-format version required by
    // the operators it uses; propagate these to the writer so old runtimes
    // refuse only files they genuinely cannot handle.
    for (auto& item : file_streams_) {
      writer_.setMinVersion(item.value().minVersion());
    }
  }

 private:
  // Pickles `value` into "<archive_name>.pkl" plus one record per tensor
  // ("<archive_name>/0", "<archive_name>/1", ...), then converts any class
  // types encountered during pickling so their source gets written too.
  void writeArchive(const std::string& archive_name, const IValue& value) {
    std::vector<char> data;
    // Vector to capture the run-time class types during pickling the IValues
    std::vector<c10::ClassTypePtr> memorizedClassTypes;
    Pickler data_pickle(
        [&](const char* buf, size_t size) {
          data.insert(data.end(), buf, buf + size);
        },
        nullptr,
        [&](const c10::ClassTypePtr& t) {
          return type_name_uniquer_.getUniqueName(t);
        },
        &memorizedClassTypes);
    data_pickle.protocol();
    data_pickle.pushIValue(value);
    data_pickle.stop();
    size_t i = 0;
    std::string prefix = archive_name + "/";
    // One record per tensor referenced by the pickled value.
    for (const auto& td : data_pickle.tensorData()) {
      WriteableTensorData writable_td = getWriteableTensorData(td);
      std::string fname = prefix + c10::to_string(i++);
      writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
    }
    std::string fname = archive_name + ".pkl";
    writer_.writeRecord(fname, data.data(), data.size());

    // serialize all the captured run-time class types
    for (const c10::ClassTypePtr& wroteType : memorizedClassTypes) {
      convertNamedType(wroteType);
    }
  }

  // Writes each entry of `extra_files` as "extra/<key>", then lets the
  // globally registered hook (if any) contribute more files; hook files
  // whose key collides with an explicitly provided one are skipped with a
  // one-time warning.
  void writeExtraFiles(const Module& module, const ExtraFilesMap& extra_files) {
    // Write out extra files.
    for (const auto& kv : extra_files) {
      const std::string key = "extra/" + kv.first;
      writer_.writeRecord(key, kv.second.data(), kv.second.size());
    }
    auto hook = GetExtraFilesHook();
    if (hook) {
      ExtraFilesMap hook_files = hook(module);
      for (const auto& kv : hook_files) {
        // Checks if the hooked file is already written in extra files,
        // if so, skips it and warns
        if (extra_files.find(kv.first) != extra_files.end()) {
          TORCH_WARN_ONCE(
              "An extra files hook attempted to write ",
              kv.first,
              " but ",
              "this is already written in extra files and so will be skipped. ",
              "This warning will only appear once per process.");
          continue;
        }
        const std::string key = "extra/" + kv.first;
        writer_.writeRecord(key, kv.second.data(), kv.second.size());
      }
    }
  }

  // Prints `root_type` and its (transitive) class dependencies as Python
  // source, writing one record per generated file under "code/" plus a
  // pickled ".debug_pkl" source-range record alongside each.
  void writeCode(const at::NamedTypePtr& root_type) {
    class_deps_.push_back(root_type);
    for (size_t i = 0; i < class_deps_.size(); ++i) {
      // note: convertNameType may extend class_deps_, so re-checking
      // .size() is necessary
      convertNamedType(class_deps_[i]);
    }

    // Mapping of filename => src. We need this because multiple classes may go
    // in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
    for (auto& item : file_streams_) {
      const std::string filename = qualifierToArchivePath(item.key(), "code/");

      std::string src = item.value().str();

      // Only compress these records if they're not tiny.
      // The cpu cost of generating zip datastructs and compressing isn't
      // well-spent for very small records.
      static constexpr size_t kMinToCompress = 200;

      writer_.writeRecord(
          filename,
          src.c_str(),
          src.size(),
          src.size() > kMinToCompress /*compress*/);

      // Write out the debug information
      std::string debugFilename = filename + ".debug_pkl";
      SourceRangePickler source_range_pickler;
      auto range_data = source_range_pickler.pickle(item.value().ranges());
      writer_.writeRecord(
          debugFilename,
          range_data.data(),
          range_data.size(),
          range_data.size() > kMinToCompress /*compress*/);
    }
  }

  // Emits the mobile "bytecode" archive: a tuple of per-method bytecode
  // tables produced by moduleMethodsTuple.
  void writeByteCode(const Module& module) {
    std::vector<c10::IValue> elements;
    moduleMethodsTuple(module, elements);
    auto telements = Tup(std::move(elements));
    writeArchive("bytecode", telements);
  }

  // Prints `class_type` (once) into the PythonPrint stream for its
  // qualifier, creating the stream on first use. Deduplicates via
  // converted_types_.
  void convertNamedType(const c10::NamedTypePtr& class_type) {
    if (converted_types_.count(class_type)) {
      return;
    }
    converted_types_.insert(class_type);
    auto qualname = type_name_uniquer_.getUniqueName(class_type);
    std::string qualifier = qualname.prefix();
    PythonPrint* pp = file_streams_.find(qualifier);

    // Maps named types to their uniquified qualified names when printing.
    auto type_printer =
        [&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
      auto namedType = t->cast<c10::NamedType>();
      if (namedType && namedType->name()) {
        return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
      }
      return c10::nullopt;
    };
    if (!pp) {
      pp = &file_streams_.insert(
          qualifier,
          PythonPrint(
              constant_table_,
              class_deps_,
              type_printer,
              /*enforce_importable=*/true));
    }
    pp->printNamedType(class_type);
  }

  // Destination archive writer.
  caffe2::serialize::PyTorchStreamWriter writer_;
  // Tensor constants collected while printing code; written to the
  // "constants" archive by serialize().
  std::vector<at::Tensor> constant_table_;
  // Types already printed, to avoid emitting a class twice.
  std::unordered_set<c10::NamedTypePtr> converted_types_;
  // Worklist of class dependencies discovered while printing; may grow
  // while being traversed (see writeCode).
  std::vector<c10::NamedTypePtr> class_deps_;
  TypeNameUniquer type_name_uniquer_;

  // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
  // created
  OrderedDict<std::string, PythonPrint> file_streams_;
};
|
|
|
|
void ExportModule(
|
|
const Module& module,
|
|
std::ostream& out,
|
|
const ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
ScriptModuleSerializer serializer(
|
|
[&](const void* buf, size_t nbytes) -> size_t {
|
|
out.write(static_cast<const char*>(buf), nbytes);
|
|
return !out ? 0 : nbytes;
|
|
});
|
|
serializer.serialize(module, extra_files, bytecode_format);
|
|
}
|
|
|
|
void ExportModule(
|
|
const Module& module,
|
|
const std::string& filename,
|
|
const ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
ScriptModuleSerializer serializer(filename);
|
|
serializer.serialize(module, extra_files, bytecode_format);
|
|
}
|
|
|
|
void ExportModule(
|
|
const Module& module,
|
|
const std::function<size_t(const void*, size_t)>& writer_func,
|
|
const ExtraFilesMap& extra_files,
|
|
bool bytecode_format) {
|
|
ScriptModuleSerializer serializer(writer_func);
|
|
serializer.serialize(module, extra_files, bytecode_format);
|
|
}
|
|
|
|
namespace {
|
|
// Collects the fully-qualified operator names ("name" or "name.overload")
// used by all of the module's methods into `opnames`.
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
  std::vector<c10::IValue> elements;
  moduleMethodsTuple(m, elements);
  for (const auto& element : elements) {
    // element is (qualified_name, table); index 1 is the bytecode table.
    auto table = element.toTuple()->elements()[1];
    auto row =
        table.toTuple()->elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
    TORCH_INTERNAL_ASSERT(
        row->elements().at(0).toStringRef() == "operators",
        "Expected operators but found ",
        row->elements().at(0).toStringRef());
    const auto& ops_list = row->elements().at(1).toTuple()->elements();
    for (const auto& op : ops_list) {
      // Take the tuple's element vector and its strings by const reference
      // to avoid copying them for every operator. The referenced storage is
      // kept alive by `op` (and by `op_item`) for the duration of the loop
      // body.
      const auto& op_item = op.toTuple()->elements();
      TORCH_CHECK(
          op_item.size() == 2,
          "There should be two parts in an operator name.");
      const auto& opname = op_item[0].toString()->string();
      const auto& overload = op_item[1].toString()->string();
      opnames.emplace(overload.empty() ? opname : opname + "." + overload);
    }
  }
}
|
|
} // namespace
|
|
|
|
// Public wrapper: returns the module's operator names as a sorted,
// duplicate-free vector.
std::vector<std::string> export_opnames(const script::Module& m) {
  std::set<std::string> unique_names;
  export_opnames(m, unique_names);
  std::vector<std::string> result(unique_names.begin(), unique_names.end());
  return result;
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|