mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
This reverts commit c5b6be045f.
### Description
Revert
### Motivation and Context
This needs a simpler and more robust approach.
930 lines
37 KiB
C++
930 lines
37 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <memory>
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/flatbuffers/schema/ort.fbs.h"
|
|
#include "core/flatbuffers/flatbuffers_utils.h"
|
|
#include "core/framework/tensorprotoutils.h"
|
|
#include "core/graph/model.h"
|
|
#include "core/graph/model_load_utils.h"
|
|
|
|
#ifdef _MSC_VER
|
|
#pragma warning(push)
|
|
// 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
|
#pragma warning(disable : 4800)
|
|
#endif
|
|
#include <google/protobuf/io/coded_stream.h>
|
|
#ifdef _MSC_VER
|
|
#pragma warning(pop)
|
|
#endif
|
|
#include "core/util/protobuf_parsing_utils.h"
|
|
|
|
#include <gsl/gsl>
|
|
|
|
#include "core/platform/env.h"
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
#include "core/graph/schema_registry.h"
|
|
#include "core/graph/function_utils.h"
|
|
#endif
|
|
|
|
#if defined(__wasm__)
|
|
#include <emscripten.h>
|
|
#endif
|
|
|
|
using namespace ONNX_NAMESPACE;
|
|
using namespace onnxruntime;
|
|
using namespace onnxruntime::common;
|
|
|
|
namespace onnxruntime {
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
|
|
void Model::RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained) {
|
|
auto* local_functions = model_proto_.mutable_functions();
|
|
if (retained.empty()) {
|
|
model_local_function_templates_maps_.clear();
|
|
model_local_functions_.clear();
|
|
local_functions->erase(local_functions->begin(), local_functions->end());
|
|
} else {
|
|
const auto retained_end = retained.cend();
|
|
for (auto it = model_local_functions_.begin();
|
|
it != model_local_functions_.end();) {
|
|
if (retained.find(it->first) == retained_end) {
|
|
model_local_function_templates_maps_.erase(it->first);
|
|
it = model_local_functions_.erase(it);
|
|
} else {
|
|
++it;
|
|
}
|
|
}
|
|
|
|
for (auto it = local_functions->begin(); it != local_functions->end();) {
|
|
const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name());
|
|
if (retained.find(function_id) == retained_end) {
|
|
it = local_functions->erase(it);
|
|
} else {
|
|
++it;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Upper bound, in bytes (4 MiB), for the read block size passed to protobuf's FileInputStream
// when parsing a model from a file descriptor.
static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * (1 << 20);
|
|
|
|
/// Builds a new, empty model named `graph_name`.
///
/// Metadata, opset imports and model-local functions are copied into `model_proto_`, a schema
/// is created for every local function, and finally the main Graph is constructed.
/// When `domain_to_version` is empty, the opset versions reported by the schema registry are used.
Model::Model(const std::string& graph_name,
             bool is_onnx_domain_only,
             const ModelMetaData& model_metadata,
             const PathString& model_path,
             const IOnnxRuntimeOpSchemaRegistryList& local_registries,
             const std::unordered_map<std::string, int>& domain_to_version,
             const std::vector<ONNX_NAMESPACE::FunctionProto>& model_local_functions,
             const logging::Logger& logger,
             const ModelOptions& options)
    : model_path_(model_path) {
  model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
  model_proto_.mutable_graph()->set_name(graph_name);

  // Keep a copy of the metadata and mirror it into the proto.
  model_metadata_ = model_metadata;
  for (const auto& [meta_key, meta_value] : model_metadata_) {
    const gsl::not_null<StringStringEntryProto*> entry{model_proto_.add_metadata_props()};
    entry->set_key(meta_key);
    entry->set_value(meta_value);
  }

  auto schema_registry = std::make_shared<SchemaRegistryManager>();
  for (const auto& registry : local_registries) {
    schema_registry->RegisterRegistry(registry);
  }

  // IsAllowReleasedONNXOpsetsOnlySet() reflects a process-wide env var, while
  // `options.allow_released_opsets_only` is specific to this Model instance.
  // Only released opsets are accepted iff both are true.
  const bool released_opsets_only =
      options.allow_released_opsets_only && model_load_utils::IsAllowReleasedONNXOpsetsOnlySet();

  const DomainToVersionMap registry_versions =
      released_opsets_only ? schema_registry->GetLastReleasedOpsetVersions(is_onnx_domain_only)
                           : schema_registry->GetLatestOpsetVersions(is_onnx_domain_only);

  // Fall back to the registry's versions when the caller did not specify any.
  const auto& effective_versions = domain_to_version.empty() ? registry_versions : domain_to_version;

  for (const auto& [opset_domain, opset_version] : effective_versions) {
    model_load_utils::ValidateOpsetForDomain(registry_versions, logger, released_opsets_only,
                                             opset_domain, opset_version);
    const gsl::not_null<OperatorSetIdProto*> opset_import{model_proto_.add_opset_import()};
    opset_import->set_domain(opset_domain);
    opset_import->set_version(opset_version);
  }

  // Copy the local functions into the proto and index the copies by identifier.
  model_local_functions_.reserve(model_local_functions.size());
  for (const auto& source_function : model_local_functions) {
    auto* owned_function = model_proto_.add_functions();
    owned_function->CopyFrom(source_function);
    model_local_functions_.insert_or_assign(
        function_utils::GetFunctionIdentifier(owned_function->domain(), owned_function->name()),
        owned_function);
  }

  // Create a schema + template for each function now stored in the proto.
  model_local_function_templates_maps_.reserve(model_proto_.functions().size());
  for (const auto& function_proto : model_proto_.functions()) {
    auto function_template = std::make_unique<FunctionTemplate>();
    function_template->op_schema_ = function_utils::CreateSchema(function_proto.domain(),
                                                                 function_proto.name(),
                                                                 model_local_functions_,
                                                                 effective_versions,
                                                                 *schema_registry,
                                                                 logger,
                                                                 released_opsets_only);
    function_template->onnx_func_proto_ = &function_proto;
    model_local_function_templates_maps_.insert_or_assign(
        function_utils::GetFunctionIdentifier(function_proto.domain(), function_proto.name()),
        std::move(function_template));
  }

  // need to call private ctor so can't use make_shared
  GSL_SUPPRESS(r .11)
  graph_.reset(new Graph(*this, model_proto_.mutable_graph(), effective_versions, IrVersion(), schema_registry,
                         logger, options.strict_shape_type_inference));
}
|
|
|
|
/// Copying constructor: duplicates `model_proto` and delegates to the ModelProto&& constructor,
/// leaving the caller's proto untouched.
Model::Model(const ModelProto& model_proto,
             const PathString& model_path,
             const IOnnxRuntimeOpSchemaRegistryList* local_registries,
             const logging::Logger& logger,
             const ModelOptions& options)
    : Model(ModelProto(model_proto), model_path, local_registries, logger, options) {}
|
|
|
|
/// Builds a Model that takes ownership of `model_proto`.
///
/// Throws (via ORT_THROW) when the proto has no graph, no opset import, no IR version, or an
/// IR version newer than the one supported. Otherwise the metadata, opset imports and
/// model-local functions are indexed and the main Graph is created.
Model::Model(ModelProto&& model_proto, const PathString& model_path,
             const IOnnxRuntimeOpSchemaRegistryList* local_registries,
             const logging::Logger& logger, const ModelOptions& options)
    : model_path_(model_path) {
  // Validate first; model_proto is only moved from once all checks pass.
  if (!utils::HasGraph(model_proto)) {
    ORT_THROW("ModelProto does not have a graph.");
  }

  if (model_proto.opset_import_size() == 0) {
    ORT_THROW(
        "Missing opset in the model. All ModelProtos MUST have at least one entry that"
        " specifies which version of the ONNX OperatorSet is being imported.");
  }

  if (!model_proto.has_ir_version()) {
    ORT_THROW("Missing model IR version.");
  }

  const auto ir_version = model_proto.ir_version();
  if (ir_version > ONNX_NAMESPACE::Version::IR_VERSION) {
    ORT_THROW("Unsupported model IR version: ", ir_version,
              ", max supported IR version: ", ONNX_NAMESPACE::Version::IR_VERSION);
  }

  model_proto_ = std::move(model_proto);
  for (const auto& entry : model_proto_.metadata_props()) {
    model_metadata_.insert_or_assign(entry.key(), entry.value());
  }

  auto schema_registry = std::make_shared<SchemaRegistryManager>();
  if (local_registries != nullptr) {
    for (const auto& registry : *local_registries) {
      schema_registry->RegisterRegistry(registry);
    }
  }

  // IsAllowReleasedONNXOpsetsOnlySet() reflects a process-wide env var, while
  // `options.allow_released_opsets_only` is specific to this Model instance.
  // Only released opsets are accepted iff both are true.
  const bool released_opsets_only =
      options.allow_released_opsets_only && model_load_utils::IsAllowReleasedONNXOpsetsOnlySet();

  const auto released_versions = schema_registry->GetLastReleasedOpsetVersions(false);

  std::unordered_map<std::string, int> domain_to_version;
  for (const auto& opset_import : model_proto_.opset_import()) {
    const auto& domain = opset_import.domain();
    const auto version = gsl::narrow_cast<int>(opset_import.version());

    // empty domain and 'ai.onnx' are equivalent
    if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) {
      // TODO: Check if we can upgrade all the current opset 6 models that are being tested
      // in CI to opset 7 or above
      LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
                               "with opset version 7 or above for opset domain 'ai.onnx'. "
                               "Please upgrade your model to opset 7 or higher. "
                               "For now, this opset "
                            << version
                            << " model may run depending upon legacy support "
                               "of some older opset version operators.";
    }

    model_load_utils::ValidateOpsetForDomain(released_versions, logger, released_opsets_only, domain, version);

    // Store the 'ai.onnx' alias under the canonical ONNX domain key. Otherwise the registry
    // loop below would not find the canonical key and would add the registry's own version
    // for it, effectively ignoring the opset version the model asked for.
    if (domain != kOnnxDomainAlias) {
      domain_to_version[domain] = version;
    } else {
      domain_to_version[kOnnxDomain] = version;
    }
  }

  // special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported
  if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) {
    if (const auto onnx_entry = domain_to_version.find(kOnnxDomain); onnx_entry != domain_to_version.end()) {
      const int onnx_opset = onnx_entry->second;  // read before the insertion below
      domain_to_version[kMSInternalNHWCDomain] = onnx_opset;
    }
  }

  // Add every registry domain the model did not import, and record it in the proto as well.
  const auto registry_versions = released_opsets_only ? schema_registry->GetLastReleasedOpsetVersions(false)
                                                      : schema_registry->GetLatestOpsetVersions(false);
  for (const auto& [registry_domain, registry_version] : registry_versions) {
    const bool newly_added = domain_to_version.try_emplace(registry_domain, registry_version).second;
    if (newly_added) {
      const gsl::not_null<OperatorSetIdProto*> added_import{model_proto_.add_opset_import()};
      added_import->set_domain(registry_domain);
      added_import->set_version(registry_version);
    }
  }

  // Index the model-local functions by identifier.
  const auto& function_protos = model_proto_.functions();
  model_local_functions_.reserve(function_protos.size());
  for (const auto& function_proto : function_protos) {
    model_local_functions_.insert_or_assign(
        function_utils::GetFunctionIdentifier(function_proto.domain(), function_proto.name()), &function_proto);
  }

  // Create a schema + template for each model-local function.
  model_local_function_templates_maps_.reserve(function_protos.size());
  for (const auto& function_proto : function_protos) {
    auto function_template = std::make_unique<FunctionTemplate>();
    function_template->op_schema_ = function_utils::CreateSchema(function_proto.domain(),
                                                                 function_proto.name(),
                                                                 model_local_functions_,
                                                                 domain_to_version,
                                                                 *schema_registry,
                                                                 logger,
                                                                 released_opsets_only);
    function_template->onnx_func_proto_ = &function_proto;
    model_local_function_templates_maps_.insert_or_assign(
        function_utils::GetFunctionIdentifier(function_proto.domain(), function_proto.name()),
        std::move(function_template));
  }

  // create instance. need to call private ctor so can't use make_unique
  GSL_SUPPRESS(r .11)
  graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry,
                         logger, options.strict_shape_type_inference));
}
|
|
|
|
/// Read-only access to the function templates built for the model-local functions,
/// keyed by function identifier.
const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& Model::GetModelLocalFunctionTemplates() const {
  return model_local_function_templates_maps_;
}
|
|
|
|
/// Returns the IR version stored in the model proto, or kNoVersion when none is set.
Version Model::IrVersion() const {
  if (!utils::HasIrVersion(model_proto_)) {
    return kNoVersion;
  }
  return model_proto_.ir_version();
}
|
|
|
|
const std::string Model::ProducerName() const {
|
|
if (model_proto_.has_producer_name()) {
|
|
return model_proto_.producer_name();
|
|
}
|
|
return std::string();
|
|
}
|
|
|
|
/// Stores `producer_name` in the model proto.
void Model::SetProducerName(const std::string& producer_name) {
  model_proto_.set_producer_name(producer_name);
}
|
|
|
|
const std::string Model::ProducerVersion() const {
|
|
if (model_proto_.has_producer_version()) {
|
|
return model_proto_.producer_version();
|
|
}
|
|
return std::string();
|
|
}
|
|
|
|
/// Stores `producer_version` in the model proto.
void Model::SetProducerVersion(const std::string& producer_version) {
  model_proto_.set_producer_version(producer_version);
}
|
|
|
|
const std::string Model::Domain() const {
|
|
if (model_proto_.has_domain()) {
|
|
return model_proto_.domain();
|
|
}
|
|
return std::string();
|
|
}
|
|
|
|
/// Stores `domain` in the model proto.
void Model::SetDomain(const std::string& domain) {
  model_proto_.set_domain(domain);
}
|
|
|
|
/// Returns the model version stored in the model proto, or kNoVersion when none is set.
Version Model::ModelVersion() const {
  if (!utils::HasModelVersion(model_proto_)) {
    return kNoVersion;
  }
  return model_proto_.model_version();
}
|
|
|
|
/// Stores `version` as the model version in the model proto.
void Model::SetModelVersion(onnxruntime::Version version) {
  model_proto_.set_model_version(version);
}
|
|
|
|
const std::string Model::DocString() const {
|
|
if (model_proto_.has_doc_string()) {
|
|
return model_proto_.doc_string();
|
|
}
|
|
return std::string();
|
|
}
|
|
|
|
/// Stores `doc_string` as the model-level doc string in the model proto.
void Model::SetDocString(const std::string& doc_string) {
  model_proto_.set_doc_string(doc_string);
}
|
|
|
|
const std::string Model::GraphDocString() const {
|
|
if (model_proto_.has_graph() && model_proto_.graph().has_doc_string()) {
|
|
return model_proto_.graph().doc_string();
|
|
}
|
|
return std::string();
|
|
}
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
/// Read-only access to the model's metadata key/value pairs.
const ModelMetaData& Model::MetaData() const noexcept {
  return model_metadata_;
}
|
|
|
|
/// Mutable access to the model's main graph. Dereferences `graph_` without a null check.
Graph& Model::MainGraph() noexcept {
  return *graph_;
}
|
|
|
|
/// Read-only access to the model's main graph. Dereferences `graph_` without a null check.
const Graph& Model::MainGraph() const noexcept {
  return *graph_;
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
/// Returns a copy of the model proto whose graph is replaced by the current graph state.
ModelProto Model::ToProto() const {
  // We want to hand back the original proto untouched. The cast selects the const overload of
  // ToGraphProto(), which returns by value and can therefore filter out the dense duplicates
  // of sparse initializers without modifying the stored proto.
  ModelProto proto_copy(model_proto_);
  *proto_copy.mutable_graph() = static_cast<const Graph&>(*graph_).ToGraphProto();
  return proto_copy;
}
|
|
|
|
/// Returns a copy of the model proto whose graph is produced by
/// Graph::ToGraphProtoWithExternalInitializers with the given arguments.
ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name,
                                                       const std::filesystem::path& file_path,
                                                       size_t initializer_size_threshold,
                                                       const Graph::OffsetAlignmentInfo& align_info) const {
  ModelProto proto_copy(model_proto_);
  // Call through a const reference, as the stored proto must stay unmodified.
  const Graph& main_graph = *graph_;
  *proto_copy.mutable_graph() = main_graph.ToGraphProtoWithExternalInitializers(
      external_file_name, file_path, initializer_size_threshold, align_info);
  return proto_copy;
}
|
|
|
|
/// Parses a ModelProto from `model_istream` into `*p_model_proto`.
///
/// Returns INVALID_ARGUMENT when the stream is not in a good state or the output pointer is
/// null, and INVALID_PROTOBUF when parsing fails or the stream was not consumed up to EOF.
Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) {
  if (!model_istream.good()) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object.");
  }
  if (p_model_proto == nullptr) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Null model_proto ptr.");
  }

  google::protobuf::io::IstreamInputStream stream_adapter(&model_istream);
  const bool parsed = p_model_proto->ParseFromZeroCopyStream(&stream_adapter);
  if (!parsed || !model_istream.eof()) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Failed to load model because protobuf parsing failed.");
  }
  return Status::OK();
}
|
|
|
|
/// Convenience overload: loads a Model from a copy of `model_proto` with an empty model path.
Status Model::Load(const ModelProto& model_proto,
                   std::shared_ptr<Model>& model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger,
                   const ModelOptions& options) {
  const PathString no_model_path{};
  return Model::Load(model_proto, no_model_path, model, local_registries, logger, options);
}
|
|
|
|
/// Creates a Model from a copy of `model_proto` and resolves its main graph.
///
/// Returns INVALID_ARGUMENT when the proto has no graph or when Model construction throws;
/// otherwise returns the outcome of Graph::Resolve.
Status Model::Load(const ModelProto& model_proto,
                   const PathString& model_path,
                   std::shared_ptr<Model>& model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger,
                   const ModelOptions& options) {
  // we expect a graph to be present
  if (!utils::HasGraph(model_proto)) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
  }

  // need to call private ctor so can't use make_shared
  GSL_SUPPRESS(r .11)

  auto construction_status = Status::OK();
  ORT_TRY {
    model = std::make_unique<Model>(model_proto, model_path, local_registries, logger, options);
  }
  ORT_CATCH(const std::exception& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      construction_status =
          Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
    });
  }
  ORT_RETURN_IF_ERROR(construction_status);

  // The graph was just built from the proto, so it does not need to be synced back.
  Graph::ResolveOptions resolve_opts;
  resolve_opts.no_proto_sync_required = true;
  ORT_RETURN_IF_ERROR(model->MainGraph().Resolve(resolve_opts));

  return construction_status;
}
|
|
|
|
/// Convenience overload: loads a Model that takes ownership of `model_proto`, with an empty
/// model path.
Status Model::Load(ModelProto&& model_proto,
                   std::shared_ptr<Model>& model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger,
                   const ModelOptions& options) {
  const PathString no_model_path{};
  return Model::Load(std::move(model_proto), no_model_path, model, local_registries, logger, options);
}
|
|
|
|
/// Creates a Model that takes ownership of `model_proto` and resolves its main graph.
///
/// Returns INVALID_ARGUMENT when the proto has no graph or when Model construction throws;
/// otherwise returns the outcome of Graph::Resolve.
Status Model::Load(ModelProto&& model_proto,
                   const PathString& model_path,
                   std::shared_ptr<Model>& model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger,
                   const ModelOptions& options) {
  // we expect a graph to be present
  if (!utils::HasGraph(model_proto)) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No graph was found in the protobuf.");
  }

  // need to call private ctor so can't use make_shared
  GSL_SUPPRESS(r .11)
  auto construction_status = Status::OK();
  ORT_TRY {
    model = std::make_unique<Model>(std::move(model_proto), model_path, local_registries, logger, options);
  }
  ORT_CATCH(const std::exception& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      construction_status =
          Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
    });
  }
  ORT_RETURN_IF_ERROR(construction_status);

  // The graph was just built from the proto, so it does not need to be synced back.
  Graph::ResolveOptions resolve_opts;
  resolve_opts.no_proto_sync_required = true;
  ORT_RETURN_IF_ERROR(model->MainGraph().Resolve(resolve_opts));

  return construction_status;
}
|
|
|
|
template <typename T, typename Loader>
|
|
static Status LoadModelHelper(const T& file_path, Loader loader) {
|
|
int fd;
|
|
Status status = Env::Default().FileOpenRd(file_path, fd);
|
|
if (!status.IsOK()) {
|
|
if (status.Category() == common::SYSTEM) {
|
|
switch (status.Code()) {
|
|
case ENOENT:
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, NO_SUCHFILE, "Load model ", ToUTF8String(file_path),
|
|
" failed. File doesn't exist");
|
|
case EINVAL:
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Load model ", ToUTF8String(file_path), " failed");
|
|
default:
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "system error number ", status.Code());
|
|
}
|
|
}
|
|
}
|
|
|
|
ORT_TRY {
|
|
status = loader(fd);
|
|
}
|
|
ORT_CATCH(const std::exception& ex) {
|
|
ORT_HANDLE_EXCEPTION([&]() {
|
|
status = Status(ONNXRUNTIME, FAIL, ex.what());
|
|
});
|
|
}
|
|
|
|
if (!status.IsOK()) {
|
|
GSL_SUPPRESS(es .84)
|
|
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
|
|
return status;
|
|
}
|
|
return Env::Default().FileClose(fd);
|
|
}
|
|
|
|
template <typename T>
|
|
static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) {
|
|
const auto loader = [&model_proto](int fd) {
|
|
return Model::Load(fd, model_proto);
|
|
};
|
|
|
|
return LoadModelHelper(file_path, loader);
|
|
}
|
|
|
|
template <typename T>
|
|
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
|
|
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
|
const logging::Logger& logger, const ModelOptions& options) {
|
|
const auto loader = [&file_path, &p_model, local_registries, &logger, &options](int fd) {
|
|
return Model::Load(fd, ToPathString(file_path), p_model, local_registries, logger, options);
|
|
};
|
|
|
|
return LoadModelHelper(file_path, loader);
|
|
}
|
|
|
|
template <typename T>
|
|
static Status SaveModel(Model& model, const T& file_path) {
|
|
#if defined(__wasm__) && defined(ORT_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL)
|
|
ORT_RETURN_IF_ERROR(model.MainGraph().Resolve());
|
|
auto model_proto = model.ToProto();
|
|
auto buffer_size = model_proto.ByteSizeLong();
|
|
void* buffer = malloc(buffer_size);
|
|
model_proto.SerializeToArray(buffer, buffer_size);
|
|
|
|
EM_ASM(({
|
|
const buffer = Number($0);
|
|
const buffer_size = Number($1);
|
|
const file_path = UTF8ToString($2);
|
|
const bytes = new Uint8Array(buffer_size);
|
|
bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size));
|
|
if (typeof process == 'object' && typeof process.versions == 'object' &&
|
|
typeof process.versions.node == 'string') {
|
|
// Node.js
|
|
require('fs').writeFileSync(file_path, bytes);
|
|
} else {
|
|
// Browser
|
|
const file = new File([bytes], file_path, {type: "application/octet-stream" });
|
|
const url = URL.createObjectURL(file);
|
|
window.open(url, '_blank');
|
|
}
|
|
}),
|
|
buffer,
|
|
buffer_size,
|
|
file_path.c_str());
|
|
|
|
free(buffer);
|
|
return Status::OK();
|
|
|
|
#else
|
|
int fd;
|
|
Status status = Env::Default().FileOpenWr(file_path, fd);
|
|
ORT_RETURN_IF_ERROR(status);
|
|
|
|
ORT_TRY {
|
|
status = Model::Save(model, fd);
|
|
}
|
|
ORT_CATCH(const std::exception& ex) {
|
|
ORT_HANDLE_EXCEPTION([&]() {
|
|
status = Status(ONNXRUNTIME, FAIL, ex.what());
|
|
});
|
|
}
|
|
if (!status.IsOK()) {
|
|
GSL_SUPPRESS(es .84)
|
|
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
|
|
return status;
|
|
}
|
|
return Env::Default().FileClose(fd);
|
|
#endif
|
|
}
|
|
|
|
/// Saves `model` to the file at `file_path`.
Status Model::Save(Model& model, const PathString& file_path) {
  return SaveModel(model, file_path);
}
|
|
|
|
template <typename T>
|
|
static Status SaveModelWithExternalInitializers(Model& model,
|
|
const T& file_path,
|
|
const std::filesystem::path& external_file_name,
|
|
size_t initializer_size_threshold,
|
|
const Graph::OffsetAlignmentInfo& align_info) {
|
|
int fd = 0;
|
|
Status status = Env::Default().FileOpenWr(file_path, fd);
|
|
ORT_RETURN_IF_ERROR(status);
|
|
|
|
ORT_TRY {
|
|
status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name,
|
|
initializer_size_threshold,
|
|
align_info);
|
|
}
|
|
ORT_CATCH(const std::exception& ex) {
|
|
ORT_HANDLE_EXCEPTION([&]() {
|
|
status = Status(ONNXRUNTIME, FAIL, ex.what());
|
|
});
|
|
}
|
|
if (!status.IsOK()) {
|
|
GSL_SUPPRESS(es .84)
|
|
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
|
|
return status;
|
|
}
|
|
return Env::Default().FileClose(fd);
|
|
}
|
|
|
|
/// Parses the model file at `file_path` into `model_proto`.
Status Model::Load(const PathString& file_path, ONNX_NAMESPACE::ModelProto& model_proto) {
  return LoadModel(file_path, model_proto);
}
|
|
|
|
/// Loads a Model from the file at `file_path` into `p_model`.
GSL_SUPPRESS(r .30)  // spurious warnings. p_model is potentially reset in the internal call to Load
GSL_SUPPRESS(r .35)
Status Model::Load(const PathString& file_path,
                   std::shared_ptr<Model>& p_model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger,
                   const ModelOptions& options) {
  return LoadModel(file_path, p_model, local_registries, logger, options);
}
|
|
|
|
/// Saves `model` to `file_path`, forwarding the external-initializer settings
/// (`external_file_name`, `initializer_size_threshold`, `align_info`) unchanged.
Status Model::SaveWithExternalInitializers(Model& model,
                                           const std::filesystem::path& file_path,
                                           const std::filesystem::path& external_file_name,
                                           size_t initializer_size_threshold,
                                           const Graph::OffsetAlignmentInfo& align_info) {
  return SaveModelWithExternalInitializers(model, file_path, external_file_name,
                                           initializer_size_threshold, align_info);
}
|
|
|
|
/// Parses `count` bytes starting at `p_bytes` into `model_proto`.
/// Returns INVALID_PROTOBUF when protobuf parsing fails.
Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
  if (model_proto.ParseFromArray(p_bytes, count)) {
    return Status::OK();
  }
  return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}
|
|
|
|
/// Convenience overload: loads a Model from an in-memory buffer with an empty model path.
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
                            const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                            const logging::Logger& logger,
                            const ModelOptions& options) {
  const PathString no_model_path{};
  return LoadFromBytes(count, p_bytes, no_model_path, p_model, local_registries, logger, options);
}
|
|
|
|
/// Parses `count` bytes at `p_bytes` into a ModelProto, builds a Model from it and resolves
/// its main graph. Exceptions thrown by the Model constructor are not caught here.
Status Model::LoadFromBytes(int count, void* p_bytes, const PathString& model_path,
                            std::shared_ptr<Model>& p_model,
                            const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                            const logging::Logger& logger, const ModelOptions& options) {
  ModelProto parsed_proto;
  if (auto parse_status = LoadFromBytes(count, p_bytes, parsed_proto); !parse_status.IsOK()) {
    return parse_status;
  }

  p_model = std::make_shared<Model>(std::move(parsed_proto), model_path, local_registries, logger, options);

  // The graph was just built from the proto, so it does not need to be synced back.
  Graph::ResolveOptions resolve_opts;
  resolve_opts.no_proto_sync_required = true;
  ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(resolve_opts));

  return Status::OK();
}
|
|
|
|
using ::google::protobuf::io::CodedInputStream;
|
|
using ::google::protobuf::io::FileInputStream;
|
|
using ::google::protobuf::io::ZeroCopyInputStream;
|
|
|
|
/// Parses a ModelProto from the open file descriptor `fd`.
///
/// @param fd           Descriptor of a file opened for reading; must be non-negative.
/// @param model_proto  Receives the parsed model.
/// @return INVALID_ARGUMENT for a negative descriptor, INVALID_PROTOBUF when parsing fails
///         (or, on protobuf >= 3.2.0, when the stream recorded a read error), OK otherwise.
Status Model::Load(int fd, ONNX_NAMESPACE::ModelProto& model_proto) {
  if (fd < 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
  }

#if GOOGLE_PROTOBUF_VERSION >= 3002000
  size_t file_size = 0;
  int block_size = -1;  // a non-positive value lets FileInputStream pick its default
  Status st = Env::Default().GetFileLength(fd, file_size);
  if (st.IsOK()) {
    // Clamp in size_t first: converting a file size of 2 GiB or more to int before taking
    // the minimum would truncate it and could produce a tiny or negative block size.
    block_size = static_cast<int>(
        std::min<size_t>(file_size, static_cast<size_t>(DEFAULT_PROTOBUF_BLOCK_SIZE)));
  }
  FileInputStream input(fd, block_size);
  const bool result = model_proto.ParseFromZeroCopyStream(&input) && input.GetErrno() == 0;
  if (!result) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
  }
#else
  // CNTK uses ORT as a submodule in order to use its GraphIR code.
  // CNTK needs to be built with protobuf 3.1.0 for its version specific features.
  // This code block is needed to support CNTK and any other
  // GraphIR client that will be built with protobuf at a version older than 3.2.0.
  FileInputStream fs(fd);
  CodedInputStream cis(&fs);

  // Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB.
  cis.SetTotalBytesLimit(INT_MAX);
  // model_proto is a reference, so member access uses '.', not '->'.
  if (!model_proto.ParseFromCodedStream(&cis)) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
  }
#endif
  return Status::OK();
}
|
|
|
|
/// Convenience overload: loads a Model from the file descriptor `fd` with an empty model path.
Status Model::Load(int fd, std::shared_ptr<Model>& p_model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                   const logging::Logger& logger, const ModelOptions& options) {
  const PathString no_model_path{};
  return Load(fd, no_model_path, p_model, local_registries, logger, options);
}
|
|
|
|
/// Parses a ModelProto from `fd`, builds a Model from it and resolves its main graph.
/// Exceptions thrown by the Model constructor are not caught here.
Status Model::Load(int fd, const PathString& model_path, std::shared_ptr<Model>& p_model,
                   const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger,
                   const ModelOptions& options) {
  ModelProto parsed_proto;
  ORT_RETURN_IF_ERROR(Load(fd, parsed_proto));

  p_model = std::make_shared<Model>(std::move(parsed_proto), model_path, local_registries, logger, options);

  // The graph was just built from the proto, so it does not need to be synced back.
  Graph::ResolveOptions resolve_opts;
  resolve_opts.no_proto_sync_required = true;
  ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(resolve_opts));

  return Status::OK();
}
|
|
|
|
/// Resolves the main graph of `model` and writes the serialized ModelProto to `p_fd`.
/// Returns INVALID_ARGUMENT for a negative descriptor and INVALID_PROTOBUF when
/// serialization or flushing fails.
Status Model::Save(Model& model, int p_fd) {
  if (p_fd < 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> is less than 0.");
  }

  ORT_RETURN_IF_ERROR(model.MainGraph().Resolve());

  auto serialized_model = model.ToProto();
  google::protobuf::io::FileOutputStream file_stream(p_fd);
  if (!serialized_model.SerializeToZeroCopyStream(&file_stream) || !file_stream.Flush()) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf serialization failed.");
  }
  return Status::OK();
}
|
|
|
|
/// Resolves the main graph of `model`, builds the proto via
/// ToGraphProtoWithExternalInitializers and writes it to `fd`.
/// Returns INVALID_ARGUMENT for a negative descriptor and INVALID_PROTOBUF when
/// serialization or flushing fails.
Status Model::SaveWithExternalInitializers(Model& model,
                                           int fd,
                                           const std::filesystem::path& file_path,
                                           const std::filesystem::path& external_file_name,
                                           size_t initializer_size_threshold,
                                           const Graph::OffsetAlignmentInfo& align_info) {
  if (fd < 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<fd> is less than 0.");
  }

  ORT_RETURN_IF_ERROR(model.MainGraph().Resolve());

  auto serialized_model = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path,
                                                                     initializer_size_threshold, align_info);
  google::protobuf::io::FileOutputStream file_stream(fd);
  if (!serialized_model.SerializeToZeroCopyStream(&file_stream) || !file_stream.Flush()) {
    return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf serialization failed.");
  }
  return Status::OK();
}
|
|
|
|
/// Serializes this model into `builder` in ORT (flatbuffers) format and stores the resulting
/// table offset in `fbs_model`.
///
/// All strings, vectors and the graph are written before the fbs::ModelBuilder is opened.
common::Status Model::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
                                      flatbuffers::Offset<fbs::Model>& fbs_model) const {
  const auto producer_name_offset = fbs::utils::SaveStringToOrtFormat(
      builder, model_proto_.has_producer_name(), model_proto_.producer_name());
  const auto producer_version_offset = fbs::utils::SaveStringToOrtFormat(
      builder, model_proto_.has_producer_version(), model_proto_.producer_version());
  const auto domain_offset = builder.CreateSharedString(model_proto_.domain());
  const auto doc_string_offset = fbs::utils::SaveStringToOrtFormat(
      builder, model_proto_.has_doc_string(), model_proto_.doc_string());
  const auto graph_doc_string_offset = fbs::utils::SaveStringToOrtFormat(
      builder, model_proto_.has_graph() && model_proto_.graph().has_doc_string(), model_proto_.graph().doc_string());

  // One OperatorSetId table per opset import.
  std::vector<flatbuffers::Offset<fbs::OperatorSetId>> opset_offsets;
  opset_offsets.reserve(model_proto_.opset_import().size());
  for (const auto& opset_import : model_proto_.opset_import()) {
    const auto opset_domain_offset = builder.CreateSharedString(opset_import.domain());
    fbs::OperatorSetIdBuilder opset_builder(builder);
    opset_builder.add_domain(opset_domain_offset);
    opset_builder.add_version(opset_import.version());
    opset_offsets.push_back(opset_builder.Finish());
  }
  const auto opset_vector_offset = builder.CreateVector(opset_offsets);

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<onnxruntime::fbs::StringStringEntry>>>
      metadata_vector_offset{0};

  // We will not serialize an empty metadata_props
  if (!model_metadata_.empty()) {
    std::vector<flatbuffers::Offset<onnxruntime::fbs::StringStringEntry>> metadata_offsets;
    metadata_offsets.reserve(model_metadata_.size());
    for (const auto& metadata_entry : model_metadata_) {
      metadata_offsets.push_back(fbs::CreateStringStringEntryDirect(
          builder, metadata_entry.first.c_str(), metadata_entry.second.c_str()));
    }
    metadata_vector_offset = builder.CreateVector(metadata_offsets);
  }

  flatbuffers::Offset<fbs::Graph> graph_offset;
  ORT_RETURN_IF_ERROR(graph_->SaveToOrtFormat(builder, graph_offset));

  fbs::ModelBuilder model_builder(builder);
  model_builder.add_ir_version(IrVersion());
  model_builder.add_opset_import(opset_vector_offset);
  model_builder.add_producer_name(producer_name_offset);
  model_builder.add_producer_version(producer_version_offset);
  model_builder.add_domain(domain_offset);
  model_builder.add_model_version(ModelVersion());
  model_builder.add_doc_string(doc_string_offset);
  model_builder.add_graph_doc_string(graph_doc_string_offset);
  model_builder.add_metadata_props(metadata_vector_offset);
  model_builder.add_graph(graph_offset);

  // Close the table and hand its offset back to the caller.
  fbs_model = model_builder.Finish();

  return Status::OK();
}
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
/// Creates an empty Model with an empty model path; the rest is filled in by LoadFromOrtFormat.
Model::Model()
    : model_path_{} {
}
|
|
|
|
/// Creates a Model in `model` from the ORT-format (flatbuffers) representation `fbs_model`.
///
/// The metadata, descriptive fields and opset imports are loaded first, then the graph is
/// loaded via Graph::LoadFromOrtFormat. In a full build the values are also mirrored into
/// `model_proto_`; in a minimal build they are stored in dedicated members instead.
/// Returns an error for a null metadata entry, invalid opset imports, or a missing graph.
common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
#if !defined(ORT_MINIMAL_BUILD)
                                        const IOnnxRuntimeOpSchemaRegistryList* local_registries,
#endif
                                        const OrtFormatLoadOptions& load_options,
                                        const logging::Logger& logger,
                                        std::unique_ptr<Model>& model) {
  model = std::make_unique<Model>();

  // Load the model metadata
  if (const auto* fbs_metadata = fbs_model.metadata_props()) {
    model->model_metadata_.reserve(fbs_metadata->size());
    for (const auto* fbs_entry : *fbs_metadata) {
      ORT_RETURN_IF(nullptr == fbs_entry, "Null entry in metadata_props. Invalid ORT format model.");
      std::string key;
      std::string value;
      fbs::utils::LoadStringFromOrtFormat(key, fbs_entry->key());
      fbs::utils::LoadStringFromOrtFormat(value, fbs_entry->value());
      model->model_metadata_.emplace(key, value);
    }
  }

#if !defined(ORT_MINIMAL_BUILD)
  LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, producer_name, fbs_model.producer_name());
  LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, producer_version, fbs_model.producer_version());
  LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, domain, fbs_model.domain());
  LOAD_STR_FROM_ORT_FORMAT(model->model_proto_, doc_string, fbs_model.doc_string());
  if (const auto* fbs_graph_doc_string = fbs_model.graph_doc_string()) {
    model->model_proto_.mutable_graph()->set_doc_string(fbs_graph_doc_string->c_str());
  }
  model->model_proto_.set_model_version(fbs_model.model_version());
  model->model_proto_.set_ir_version(fbs_model.ir_version());

  auto schema_registry = std::make_shared<SchemaRegistryManager>();
  if (local_registries != nullptr) {
    for (const auto& registry : *local_registries) {
      schema_registry->RegisterRegistry(registry);
    }
  }

  // Populate the metadata to model_proto
  for (const auto& [meta_key, meta_value] : model->model_metadata_) {
    const gsl::not_null<StringStringEntryProto*> entry{model->model_proto_.add_metadata_props()};
    entry->set_key(meta_key);
    entry->set_value(meta_value);
  }
#else
  fbs::utils::LoadStringFromOrtFormat(model->producer_name_, fbs_model.producer_name());
  fbs::utils::LoadStringFromOrtFormat(model->producer_version_, fbs_model.producer_version());
  fbs::utils::LoadStringFromOrtFormat(model->domain_, fbs_model.domain());
  fbs::utils::LoadStringFromOrtFormat(model->doc_string_, fbs_model.doc_string());
  fbs::utils::LoadStringFromOrtFormat(model->graph_doc_string_, fbs_model.graph_doc_string());
  model->model_version_ = fbs_model.model_version();
  model->ir_version_ = fbs_model.ir_version();
#endif

  std::unordered_map<std::string, int> domain_to_version;
  ORT_RETURN_IF_ERROR(fbs::utils::LoadOpsetImportOrtFormat(fbs_model.opset_import(), domain_to_version));

  const auto* fbs_graph = fbs_model.graph();
  ORT_RETURN_IF(nullptr == fbs_graph, "Graph is null. Invalid ORT format model.");

#if !defined(ORT_MINIMAL_BUILD)
  // add the opset imports to the model_proto in case we're updating an ORT format model and need those to be
  // included when SaveToOrtFormat is called later
  for (const auto& [opset_domain, opset_version] : domain_to_version) {
    const gsl::not_null<OperatorSetIdProto*> opset_import{model->model_proto_.add_opset_import()};
    opset_import->set_domain(opset_domain);
    opset_import->set_version(opset_version);
  }

  ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry,
                                               load_options, logger, model->graph_));
#else
  ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version,
                                               load_options, logger, model->graph_));
#endif
  return Status::OK();
}
|
|
|
|
} // namespace onnxruntime
|