mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Ignore saved runtime optimizations when updating ORT format model <v5. (#13393)
The old runtime optimization format is not readily convertible to the new one without extra information for translating kernel def hashes. Ignore such saved runtime optimizations and output a warning for now.
This commit is contained in:
parent
b383312f4c
commit
215732f74b
10 changed files with 175 additions and 94 deletions
|
|
@ -46,6 +46,7 @@
|
|||
#endif
|
||||
#include "core/graph/graph_nodes.h"
|
||||
#include "core/graph/node_arg.h"
|
||||
#include "core/graph/ort_format_load_options.h"
|
||||
|
||||
namespace flatbuffers {
|
||||
class FlatBufferBuilder;
|
||||
|
|
@ -487,11 +488,11 @@ class Node {
|
|||
#endif
|
||||
|
||||
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Node>& node);
|
||||
|
||||
Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger);
|
||||
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
|
||||
|
||||
|
|
@ -1312,18 +1313,18 @@ class Graph {
|
|||
|
||||
virtual ~Graph();
|
||||
|
||||
static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
|
||||
const std::unordered_map<std::string, int>& domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
|
||||
|
||||
// deserialize a subgraph
|
||||
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
||||
Graph& parent_graph, const Node& parent_node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
@ -1353,8 +1354,8 @@ class Graph {
|
|||
bool strict_shape_type_inference);
|
||||
|
||||
// Populate Graph instance from ORT format serialized data.
|
||||
common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
||||
bool can_use_flatbuffer_for_initializers);
|
||||
Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
||||
const OrtFormatLoadOptions& load_options);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||
|
|
|
|||
|
|
@ -723,14 +723,14 @@ flatbuffers::Offset<fbs::NodeEdge> Node::SaveEdgesToOrtFormat(flatbuffers::FlatB
|
|||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Node>& node) {
|
||||
node = std::make_unique<Node>(fbs_node.index(), graph);
|
||||
return node->LoadFromOrtFormat(fbs_node, can_use_flatbuffer_for_initializers, logger);
|
||||
return node->LoadFromOrtFormat(fbs_node, load_options, logger);
|
||||
}
|
||||
|
||||
Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger) {
|
||||
auto LoadNodeArgsFromOrtFormat =
|
||||
[&](const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* fbs_node_arg_names,
|
||||
|
|
@ -769,8 +769,7 @@ Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
|
|||
AttributeProto attr_proto;
|
||||
std::unique_ptr<Graph> subgraph;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this,
|
||||
can_use_flatbuffer_for_initializers, logger));
|
||||
fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, load_options, logger));
|
||||
|
||||
// If we have a sub graph in this attributes, it will be loaded into subgraph ptr
|
||||
// while the attribute proto contains the sub graph will have the empty g() field
|
||||
|
|
@ -4200,7 +4199,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
|
||||
#endif
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
|
||||
graph = std::make_unique<Graph>(owning_model, domain_to_version,
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -4210,7 +4209,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
|||
// Assume anything in ORT format has already been validated.
|
||||
false);
|
||||
|
||||
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers));
|
||||
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, load_options));
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// in a full build we need to run Resolve to fully populate ResolveContext and Node::op_,
|
||||
|
|
@ -4226,7 +4225,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
|||
|
||||
Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
||||
Graph& parent_graph, const Node& parent_node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
|
||||
graph = std::make_unique<Graph>(parent_graph.owning_model_,
|
||||
parent_graph.domain_to_version_,
|
||||
|
|
@ -4238,7 +4237,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
|||
// Assume anything in ORT format has already been validated.
|
||||
false);
|
||||
|
||||
return graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers);
|
||||
return graph->LoadFromOrtFormat(fbs_graph, load_options);
|
||||
}
|
||||
|
||||
Graph::Graph(const Model& owning_model,
|
||||
|
|
@ -4268,7 +4267,7 @@ Graph::Graph(const Model& owning_model,
|
|||
}
|
||||
|
||||
common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
|
||||
bool can_use_flatbuffer_for_initializers) {
|
||||
const OrtFormatLoadOptions& load_options) {
|
||||
// We deserialize the graph from ORT format in the following order:
|
||||
// 1. Deserialize the initializers and sparse initializers. Convert sparse to dense.
|
||||
// 2. Deserialize the NodeArgs
|
||||
|
|
@ -4298,8 +4297,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
|
|||
for (const auto* fbs_tensor : *fbs_initializers) {
|
||||
ORT_RETURN_IF(nullptr == fbs_tensor, "Initializer tensor is missing. Invalid ORT format model.");
|
||||
TensorProto* initializer = deserialized_proto_data_.add_initializer();
|
||||
ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer,
|
||||
can_use_flatbuffer_for_initializers));
|
||||
ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, load_options));
|
||||
auto p = name_to_initial_tensor_.emplace(initializer->name(), initializer);
|
||||
if (!p.second) {
|
||||
LOGS(logger_, WARNING) << "Duplicate initializer (dense or ConstantNode): '" << initializer->name()
|
||||
|
|
@ -4318,7 +4316,8 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
|
|||
for (const auto* fbs_sparse_tensor : *fbs_sparse_initializers) {
|
||||
ORT_RETURN_IF(nullptr == fbs_sparse_tensor, "Sparse Initializer tensor is missing. Invalid ORT format model.");
|
||||
SparseTensorProto sparse_initializer;
|
||||
ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer));
|
||||
ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer,
|
||||
load_options));
|
||||
TensorProto& initializer = *deserialized_proto_data_.add_initializer();
|
||||
ORT_RETURN_IF_ERROR(utils::SparseTensorProtoToDenseTensorProto(sparse_initializer, model_path, initializer));
|
||||
auto p = name_to_initial_tensor_.emplace(initializer.name(), &initializer);
|
||||
|
|
@ -4359,8 +4358,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
|
|||
for (const auto* fbs_node : *fbs_nodes) {
|
||||
ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model.");
|
||||
std::unique_ptr<Node> node;
|
||||
ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, can_use_flatbuffer_for_initializers, logger_,
|
||||
node));
|
||||
ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, load_options, logger_, node));
|
||||
ORT_RETURN_IF(node->Index() >= fbs_graph.max_node_index(), "Node index is out of range");
|
||||
nodes_[node->Index()] = std::move(node);
|
||||
++num_of_nodes_;
|
||||
|
|
@ -4407,9 +4405,11 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
|
|||
ORT_RETURN_IF_ERROR(PopulateNodeArgToProducerConsumerLookupsFromNodes());
|
||||
|
||||
// runtime optimizations
|
||||
if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) {
|
||||
if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) {
|
||||
ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records));
|
||||
if (!load_options.ignore_saved_runtime_optimizations) {
|
||||
if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) {
|
||||
if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) {
|
||||
ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
|
|||
#endif
|
||||
|
||||
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& initializer,
|
||||
bool can_use_flatbuffer_for_initializers) {
|
||||
const OrtFormatLoadOptions& load_options) {
|
||||
initializer.Clear();
|
||||
|
||||
LOAD_STR_FROM_ORT_FORMAT(initializer, name, fbs_tensor.name());
|
||||
|
|
@ -201,7 +201,7 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
const auto* fbs_raw_data = fbs_tensor.raw_data();
|
||||
ORT_RETURN_IF(nullptr == fbs_raw_data, "Missing raw data for initializer. Invalid ORT format model.");
|
||||
|
||||
if (can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
|
||||
if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
|
||||
initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
|
||||
|
|
@ -231,19 +231,20 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor,
|
||||
SparseTensorProto& initializer) {
|
||||
SparseTensorProto& initializer,
|
||||
const OrtFormatLoadOptions& load_options) {
|
||||
SparseTensorProto loaded_initializer;
|
||||
auto fbs_values_tensor = fbs_sparse_tensor.values();
|
||||
ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model.");
|
||||
auto* values_tensor = loaded_initializer.mutable_values();
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor));
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options));
|
||||
ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model.");
|
||||
|
||||
auto fbs_indicies_tensor = fbs_sparse_tensor.indices();
|
||||
ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'",
|
||||
"Invalid ORT format model.");
|
||||
auto* indicies_tensor = loaded_initializer.mutable_indices();
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor));
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options));
|
||||
|
||||
auto fbs_dims = fbs_sparse_tensor.dims();
|
||||
ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'",
|
||||
|
|
@ -259,7 +260,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
ONNX_NAMESPACE::AttributeProto& attr_proto,
|
||||
std::unique_ptr<onnxruntime::Graph>& sub_graph,
|
||||
onnxruntime::Graph& graph, onnxruntime::Node& node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger) {
|
||||
attr_proto.Clear();
|
||||
LOAD_STR_FROM_ORT_FORMAT(attr_proto, name, fbs_attr.name());
|
||||
|
|
@ -283,7 +284,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
auto fbs_tensor = fbs_attr.t();
|
||||
ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor attribute. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *attr_proto.mutable_t(),
|
||||
can_use_flatbuffer_for_initializers));
|
||||
load_options));
|
||||
} break;
|
||||
case AttributeProto_AttributeType_GRAPH: {
|
||||
// If the attribute type is a graph, we will create an empty graph in attr_proto so that the ONNX checker
|
||||
|
|
@ -292,7 +293,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
ORT_RETURN_IF(nullptr == fbs_graph, "Null graph attribute. Invalid ORT format model.");
|
||||
attr_proto.mutable_g()->set_name("Empty graph proto from deserialization of ORT format model");
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::Graph::LoadFromOrtFormat(*fbs_graph, graph, node,
|
||||
can_use_flatbuffer_for_initializers,
|
||||
load_options,
|
||||
logger, sub_graph));
|
||||
} break;
|
||||
case AttributeProto_AttributeType_FLOATS: {
|
||||
|
|
@ -327,7 +328,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
for (const auto* fbs_tensor : *fbs_tensors) {
|
||||
ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor in tensors attribute. Invalid ORT format model.");
|
||||
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *tensors->Add(),
|
||||
can_use_flatbuffer_for_initializers));
|
||||
load_options));
|
||||
}
|
||||
} break;
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/graph/ort_format_load_options.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
class AttributeProto;
|
||||
|
|
@ -66,18 +67,16 @@ Status SaveAttributeOrtFormat(
|
|||
/// </summary>
|
||||
/// <param name="fbs_tensor">Flatbuffer Tensor</param>
|
||||
/// <param name="initializer">TensorProto to load data into</param>
|
||||
/// <param name="can_use_flatbuffer_for_initializers">
|
||||
/// If true, set the TensorProto to point to the memory in the flatbuffer instead of copying data.
|
||||
/// This requires the buffer to remain valid for the entire duration of the InferenceSession.
|
||||
/// </param>
|
||||
/// <param name="load_options">ORT format load options</param>
|
||||
/// <returns>Status</returns>
|
||||
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor,
|
||||
ONNX_NAMESPACE::TensorProto& initializer,
|
||||
bool can_use_flatbuffer_for_initializers = false);
|
||||
const OrtFormatLoadOptions& load_options);
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor,
|
||||
ONNX_NAMESPACE::SparseTensorProto& initializer);
|
||||
ONNX_NAMESPACE::SparseTensorProto& initializer,
|
||||
const OrtFormatLoadOptions& load_options);
|
||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
||||
|
||||
// Load a give fbs::Attribute into AttributeProto
|
||||
|
|
@ -87,7 +86,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
|
|||
ONNX_NAMESPACE::AttributeProto& attr_proto,
|
||||
std::unique_ptr<onnxruntime::Graph>& sub_graph,
|
||||
onnxruntime::Graph& graph, onnxruntime::Node& node,
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger);
|
||||
|
||||
} // namespace utils
|
||||
|
|
|
|||
|
|
@ -771,7 +771,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
#endif
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger,
|
||||
std::unique_ptr<Model>& model) {
|
||||
model = std::make_unique<Model>();
|
||||
|
|
@ -838,10 +838,10 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry,
|
||||
can_use_flatbuffer_for_initializers, logger, model->graph_));
|
||||
load_options, logger, model->graph_));
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version,
|
||||
can_use_flatbuffer_for_initializers, logger, model->graph_));
|
||||
load_options, logger, model->graph_));
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <string>
|
||||
#include "core/common/path.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/ort_format_load_options.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#include "core/graph/function_template.h"
|
||||
|
|
@ -297,7 +298,7 @@ class Model {
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
|
||||
#endif
|
||||
bool can_use_flatbuffer_for_initializers,
|
||||
const OrtFormatLoadOptions& load_options,
|
||||
const logging::Logger& logger,
|
||||
std::unique_ptr<Model>& model);
|
||||
|
||||
|
|
|
|||
18
onnxruntime/core/graph/ort_format_load_options.h
Normal file
18
onnxruntime/core/graph/ort_format_load_options.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/// Options to configure how an ORT format model is loaded.
|
||||
struct OrtFormatLoadOptions {
|
||||
/// If true, set initializer TensorProtos to point to memory in the flatbuffer instead of copying data.
|
||||
/// This requires the flatbuffer to remain valid for the entire duration of the InferenceSession.
|
||||
bool can_use_flatbuffer_for_initializers{true};
|
||||
|
||||
/// If true, do not load any saved runtime optimizations.
|
||||
bool ignore_saved_runtime_optimizations{false};
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1024,8 +1024,10 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
|
|||
const auto* fbs_ort_model_version = fbs_session->ort_version();
|
||||
ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model.");
|
||||
|
||||
auto model_version = std::stoi(fbs_ort_model_version->str());
|
||||
bool is_supported = IsOrtModelVersionSupported(model_version);
|
||||
const auto model_version = std::stoi(fbs_ort_model_version->str());
|
||||
const bool is_supported = IsOrtModelVersionSupported(model_version);
|
||||
|
||||
OrtFormatLoadOptions load_options{};
|
||||
|
||||
#if defined(ORT_MINIMAL_BUILD)
|
||||
// Note about the ORT format version 5 breaking change.
|
||||
|
|
@ -1038,16 +1040,34 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
|
|||
"The ORT format model version [", fbs_ort_model_version->string_view(),
|
||||
"] is not supported in this build ", ORT_VERSION, ". ",
|
||||
kOrtFormatVersion5BreakingChangeNote);
|
||||
#else
|
||||
// models prior to v5 can be handled by inserting the kernel constraints in a full build
|
||||
bool is_supported_with_update = !is_supported && model_version < 5;
|
||||
#else // ^^ defined(ORT_MINIMAL_BUILD) ^^ / vv !defined(ORT_MINIMAL_BUILD) vv
|
||||
const auto has_saved_runtime_optimizations = [](const fbs::InferenceSession& fbs_session) -> bool {
|
||||
if (const auto* fbs_model = fbs_session.model()) {
|
||||
if (const auto* fbs_graph = fbs_model->graph()) {
|
||||
if (const auto* fbs_runtime_opts = fbs_graph->runtime_optimizations()) {
|
||||
if (const auto* fbs_runtime_opt_records = fbs_runtime_opts->records()) {
|
||||
return fbs_runtime_opt_records->size() > 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// models prior to v5 can be handled by inserting the kernel constraints in a full build
|
||||
const bool is_supported_with_update = model_version < 5;
|
||||
|
||||
if (is_supported_with_update && has_saved_runtime_optimizations(*fbs_session)) {
|
||||
LOGS(*session_logger_, WARNING)
|
||||
<< "The old ORT format model (version " << fbs_ort_model_version->string_view()
|
||||
<< ") has saved runtime optimizations. They will be ignored.";
|
||||
load_options.ignore_saved_runtime_optimizations = true;
|
||||
}
|
||||
|
||||
// currently this means the model is using a future version
|
||||
// i.e. attempted load of model created with future version of ORT
|
||||
ORT_RETURN_IF_NOT(is_supported || is_supported_with_update,
|
||||
"The ORT format model version [", fbs_ort_model_version->string_view(),
|
||||
"] is not supported in this build ", ORT_VERSION, ". ");
|
||||
#endif
|
||||
"] is not supported in this build ", ORT_VERSION, ".");
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
const auto* fbs_model = fbs_session->model();
|
||||
ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model.");
|
||||
|
|
@ -1058,19 +1078,18 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
|
|||
// if that is the case we also allow creating initializers that directly use those bytes.
|
||||
const auto& config_options = session_options_.config_options;
|
||||
using_ort_model_bytes_for_initializers_ =
|
||||
ort_format_model_bytes_data_holder_.empty() &&
|
||||
config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1";
|
||||
load_options.can_use_flatbuffer_for_initializers =
|
||||
ort_format_model_bytes_data_holder_.empty() &&
|
||||
config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1";
|
||||
|
||||
// need to go from unique_ptr to shared_ptr when moving into model_
|
||||
std::unique_ptr<Model> tmp_model;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
|
||||
HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
||||
using_ort_model_bytes_for_initializers_,
|
||||
*session_logger_, tmp_model));
|
||||
load_options, *session_logger_, tmp_model));
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, using_ort_model_bytes_for_initializers_, *session_logger_,
|
||||
tmp_model));
|
||||
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, load_options, *session_logger_, tmp_model));
|
||||
#endif
|
||||
|
||||
ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model));
|
||||
|
|
|
|||
|
|
@ -18,11 +18,10 @@
|
|||
#include "flatbuffers/idl.h"
|
||||
#include "flatbuffers/util.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -97,7 +96,6 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val
|
|||
EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i;
|
||||
}
|
||||
} else {
|
||||
|
||||
ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0);
|
||||
}
|
||||
}
|
||||
|
|
@ -231,8 +229,8 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object
|
|||
ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion());
|
||||
}
|
||||
|
||||
static void SaveAndCompareModels(const std::basic_string<ORTCHAR_T>& orig_file,
|
||||
const std::basic_string<ORTCHAR_T>& ort_file,
|
||||
static void SaveAndCompareModels(const PathString& orig_file,
|
||||
const PathString& ort_file,
|
||||
TransformerLevel optimization_level = TransformerLevel::Level3) {
|
||||
SessionOptions so;
|
||||
so.session_logid = "SerializeToOrtFormat";
|
||||
|
|
@ -301,7 +299,7 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
|
|||
test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
|
||||
|
||||
OrtValue ml_value;
|
||||
vector<float> data(28 * 28, 0.0);
|
||||
std::vector<float> data(28 * 28, 0.0);
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data,
|
||||
&ml_value);
|
||||
test_info.inputs.insert(std::make_pair("Input3", ml_value));
|
||||
|
|
@ -372,59 +370,103 @@ TEST(OrtModelOnlyTests, MetadataSerialization) {
|
|||
}
|
||||
|
||||
// test we can load an old ORT format model and run it in a full build.
|
||||
// we changed from using kernel hashes to kernel type constraints in v5, so any old model should be able to be loaded
|
||||
// we changed from using kernel hashes to kernel type constraints in v5, so an old model should be able to be loaded
|
||||
// in a full build if we add the kernel type constraints during loading. this also means we can save the updated
|
||||
// ORT format model to effectively upgrade it to v5.
|
||||
TEST(OrtModelOnlyTests, UpdateOrtModelVersion) {
|
||||
// input is ORT format model using v4 where we used kernel hashes instead of constraints
|
||||
const auto onnx_file = ORT_TSTR("testdata/mnist.onnx");
|
||||
const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort");
|
||||
const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort");
|
||||
void TestOrtModelUpdate(const PathString& onnx_file,
|
||||
const PathString& ort_file_v4,
|
||||
const PathString& generated_ort_file_v5,
|
||||
const std::function<void(NameMLValMap& inputs, std::vector<std::string>& output_names)>&
|
||||
set_up_test_inputs_and_outputs_fn) {
|
||||
// ort_file_v4 is ORT format model using v4 where we used kernel hashes instead of constraints
|
||||
|
||||
// update v4 model and save as v5. do not run optimizations in order to preserve the model as-is.
|
||||
SaveAndCompareModels(ort_file_v4, ort_file_v5, TransformerLevel::Default);
|
||||
SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default);
|
||||
|
||||
// run the original, v4 and v5 models and check the output is the same
|
||||
RandomValueGenerator random{};
|
||||
std::vector<int64_t> input_dims{1, 1, 28, 28};
|
||||
std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 0.9f);
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
|
||||
input_dims, input_data, &ml_value);
|
||||
|
||||
OrtModelTestInfo test_info;
|
||||
set_up_test_inputs_and_outputs_fn(test_info.inputs, test_info.output_names);
|
||||
|
||||
// keep the onnx and ort models to the same optimization level
|
||||
test_info.optimization_level = TransformerLevel::Level1;
|
||||
|
||||
test_info.inputs.insert(std::make_pair("Input3", ml_value));
|
||||
test_info.output_names = {"Plus214_Output_0"};
|
||||
|
||||
OrtValue orig_out, v4_out, v5_out;
|
||||
std::vector<OrtValue> orig_out, v4_out, v5_out;
|
||||
|
||||
test_info.model_filename = onnx_file;
|
||||
test_info.output_verifier = [&orig_out](const std::vector<OrtValue>& fetches) {
|
||||
orig_out = fetches[0];
|
||||
orig_out = fetches;
|
||||
};
|
||||
RunOrtModel(test_info);
|
||||
|
||||
// run with v4 as input. this should also update to v5 prior to execution.
|
||||
test_info.model_filename = ort_file_v4;
|
||||
test_info.output_verifier = [&v4_out](const std::vector<OrtValue>& fetches) {
|
||||
v4_out = fetches[0];
|
||||
v4_out = fetches;
|
||||
};
|
||||
RunOrtModel(test_info);
|
||||
|
||||
// validate the model saved as v5 also works
|
||||
test_info.model_filename = ort_file_v5;
|
||||
test_info.model_filename = generated_ort_file_v5;
|
||||
test_info.output_verifier = [&v5_out](const std::vector<OrtValue>& fetches) {
|
||||
v5_out = fetches[0];
|
||||
v5_out = fetches;
|
||||
};
|
||||
RunOrtModel(test_info);
|
||||
|
||||
CompareTensors(orig_out, v4_out);
|
||||
CompareTensors(v4_out, v5_out);
|
||||
auto compare_outputs = [](gsl::span<OrtValue> expected, gsl::span<OrtValue> actual) {
|
||||
ASSERT_EQ(expected.size(), actual.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
CompareTensors(expected[i], actual[i]);
|
||||
}
|
||||
};
|
||||
|
||||
compare_outputs(orig_out, v4_out);
|
||||
compare_outputs(v4_out, v5_out);
|
||||
};
|
||||
|
||||
TEST(OrtModelOnlyTests, UpdateOrtModelVersion) {
|
||||
const auto onnx_file = ORT_TSTR("testdata/mnist.onnx");
|
||||
const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort");
|
||||
const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort");
|
||||
|
||||
RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure
|
||||
|
||||
TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5,
|
||||
[&](NameMLValMap& inputs, std::vector<std::string>& output_names) {
|
||||
std::vector<int64_t> input_dims{1, 1, 28, 28};
|
||||
std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 0.9f);
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
|
||||
input_dims, input_data, &ml_value);
|
||||
|
||||
inputs = {{"Input3", ml_value}};
|
||||
output_names = {"Plus214_Output_0"};
|
||||
});
|
||||
}
|
||||
|
||||
// test that a model with saved runtime optimizations can also be updated
|
||||
// note: the saved runtime optimizations will be ignored
|
||||
TEST(OrtModelOnlyTests, UpdateOrtModelVersionWithSavedRuntimeOptimizations) {
|
||||
const auto onnx_file = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.onnx");
|
||||
const auto ort_file_v4 = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort");
|
||||
const auto ort_file_v5 =
|
||||
ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v5.test_output.ort");
|
||||
|
||||
RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure
|
||||
|
||||
TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5,
|
||||
[&](NameMLValMap& inputs, std::vector<std::string>& output_names) {
|
||||
constexpr int n = 3; // number of QDQ convs
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
std::vector<int64_t> input_dims{1, 1, 5, 5};
|
||||
std::vector<uint8_t> input_data = random.Uniform<uint8_t>(input_dims, 0, 255);
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<uint8_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
|
||||
input_dims, input_data, &ml_value);
|
||||
|
||||
inputs.emplace(MakeString("X_", i), std::move(ml_value));
|
||||
output_names.push_back(MakeString("Y_", i));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue