Ignore saved runtime optimizations when updating ORT format model <v5. (#13393)

The old runtime optimization format is not readily convertible to the new one without extra information for translating kernel def hashes.
Ignore such saved runtime optimizations and output a warning for now.
Edward Chen 2022-11-08 13:36:46 -08:00 committed by GitHub
parent b383312f4c
commit 215732f74b
10 changed files with 175 additions and 94 deletions
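
For orientation, here is a condensed, non-authoritative sketch of the resulting full-build load path, using only names that appear in the hunks below (the InferenceSession::LoadOrtModelWithLoader and OrtFormatLoadOptions changes):

OrtFormatLoadOptions load_options{};

// In a full build, a model older than v5 can still be loaded because the kernel type
// constraints can be inserted during loading; only its saved runtime optimization
// records need to be dropped.
const bool is_supported_with_update = model_version < 5;
if (is_supported_with_update && has_saved_runtime_optimizations(*fbs_session)) {
  LOGS(*session_logger_, WARNING)
      << "The old ORT format model (version " << fbs_ort_model_version->string_view()
      << ") has saved runtime optimizations. They will be ignored.";
  load_options.ignore_saved_runtime_optimizations = true;
}

// load_options is then threaded through Model/Graph/Node::LoadFromOrtFormat in place of
// the previous bool can_use_flatbuffer_for_initializers parameter.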

View file

@@ -46,6 +46,7 @@
#endif
#include "core/graph/graph_nodes.h"
#include "core/graph/node_arg.h"
#include "core/graph/ort_format_load_options.h"
namespace flatbuffers {
class FlatBufferBuilder;
@@ -487,11 +488,11 @@ class Node {
#endif
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Node>& node);
Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger);
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
@@ -1312,18 +1313,18 @@ class Graph {
virtual ~Graph();
static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
bool can_use_flatbuffer_for_initializers,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
// deserialize a subgraph
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
Graph& parent_graph, const Node& parent_node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1353,8 +1354,8 @@ class Graph {
bool strict_shape_type_inference);
// Populate Graph instance from ORT format serialized data.
common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
bool can_use_flatbuffer_for_initializers);
Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
const OrtFormatLoadOptions& load_options);
#if !defined(ORT_MINIMAL_BUILD)
// Constructor: Given a <GraphProto> loaded from model file, construct

View file

@@ -723,14 +723,14 @@ flatbuffers::Offset<fbs::NodeEdge> Node::SaveEdgesToOrtFormat(flatbuffers::FlatB
#endif // !defined(ORT_MINIMAL_BUILD)
Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Node>& node) {
node = std::make_unique<Node>(fbs_node.index(), graph);
return node->LoadFromOrtFormat(fbs_node, can_use_flatbuffer_for_initializers, logger);
return node->LoadFromOrtFormat(fbs_node, load_options, logger);
}
Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger) {
auto LoadNodeArgsFromOrtFormat =
[&](const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* fbs_node_arg_names,
@@ -769,8 +769,7 @@ Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
AttributeProto attr_proto;
std::unique_ptr<Graph> subgraph;
ORT_RETURN_IF_ERROR(
fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this,
can_use_flatbuffer_for_initializers, logger));
fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, load_options, logger));
// If we have a sub graph in this attributes, it will be loaded into subgraph ptr
// while the attribute proto contains the sub graph will have the empty g() field
@@ -4200,7 +4199,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
#endif
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
graph = std::make_unique<Graph>(owning_model, domain_to_version,
#if !defined(ORT_MINIMAL_BUILD)
@@ -4210,7 +4209,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
// Assume anything in ORT format has already been validated.
false);
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers));
ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, load_options));
#if !defined(ORT_MINIMAL_BUILD)
// in a full build we need to run Resolve to fully populate ResolveContext and Node::op_,
@@ -4226,7 +4225,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
Graph& parent_graph, const Node& parent_node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph) {
graph = std::make_unique<Graph>(parent_graph.owning_model_,
parent_graph.domain_to_version_,
@@ -4238,7 +4237,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
// Assume anything in ORT format has already been validated.
false);
return graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers);
return graph->LoadFromOrtFormat(fbs_graph, load_options);
}
Graph::Graph(const Model& owning_model,
@@ -4268,7 +4267,7 @@ Graph::Graph(const Model& owning_model,
}
common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
bool can_use_flatbuffer_for_initializers) {
const OrtFormatLoadOptions& load_options) {
// We deserialize the graph from ORT format in the following order:
// 1. Deserialize the initializers and sparse initializers. Convert sparse to dense.
// 2. Deserialize the NodeArgs
@@ -4298,8 +4297,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
for (const auto* fbs_tensor : *fbs_initializers) {
ORT_RETURN_IF(nullptr == fbs_tensor, "Initializer tensor is missing. Invalid ORT format model.");
TensorProto* initializer = deserialized_proto_data_.add_initializer();
ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer,
can_use_flatbuffer_for_initializers));
ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, load_options));
auto p = name_to_initial_tensor_.emplace(initializer->name(), initializer);
if (!p.second) {
LOGS(logger_, WARNING) << "Duplicate initializer (dense or ConstantNode): '" << initializer->name()
@@ -4318,7 +4316,8 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
for (const auto* fbs_sparse_tensor : *fbs_sparse_initializers) {
ORT_RETURN_IF(nullptr == fbs_sparse_tensor, "Sparse Initializer tensor is missing. Invalid ORT format model.");
SparseTensorProto sparse_initializer;
ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer));
ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer,
load_options));
TensorProto& initializer = *deserialized_proto_data_.add_initializer();
ORT_RETURN_IF_ERROR(utils::SparseTensorProtoToDenseTensorProto(sparse_initializer, model_path, initializer));
auto p = name_to_initial_tensor_.emplace(initializer.name(), &initializer);
@@ -4359,8 +4358,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
for (const auto* fbs_node : *fbs_nodes) {
ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model.");
std::unique_ptr<Node> node;
ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, can_use_flatbuffer_for_initializers, logger_,
node));
ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, load_options, logger_, node));
ORT_RETURN_IF(node->Index() >= fbs_graph.max_node_index(), "Node index is out of range");
nodes_[node->Index()] = std::move(node);
++num_of_nodes_;
@@ -4407,9 +4405,11 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph
ORT_RETURN_IF_ERROR(PopulateNodeArgToProducerConsumerLookupsFromNodes());
// runtime optimizations
if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) {
if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) {
ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records));
if (!load_options.ignore_saved_runtime_optimizations) {
if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) {
if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) {
ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records));
}
}
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

View file

@@ -177,7 +177,7 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder,
#endif
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& initializer,
bool can_use_flatbuffer_for_initializers) {
const OrtFormatLoadOptions& load_options) {
initializer.Clear();
LOAD_STR_FROM_ORT_FORMAT(initializer, name, fbs_tensor.name());
@@ -201,7 +201,7 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
const auto* fbs_raw_data = fbs_tensor.raw_data();
ORT_RETURN_IF(nullptr == fbs_raw_data, "Missing raw data for initializer. Invalid ORT format model.");
if (can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
@@ -231,19 +231,20 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
#if !defined(DISABLE_SPARSE_TENSORS)
Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor,
SparseTensorProto& initializer) {
SparseTensorProto& initializer,
const OrtFormatLoadOptions& load_options) {
SparseTensorProto loaded_initializer;
auto fbs_values_tensor = fbs_sparse_tensor.values();
ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model.");
auto* values_tensor = loaded_initializer.mutable_values();
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor));
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options));
ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model.");
auto fbs_indicies_tensor = fbs_sparse_tensor.indices();
ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'",
"Invalid ORT format model.");
auto* indicies_tensor = loaded_initializer.mutable_indices();
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor));
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options));
auto fbs_dims = fbs_sparse_tensor.dims();
ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'",
@@ -259,7 +260,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
ONNX_NAMESPACE::AttributeProto& attr_proto,
std::unique_ptr<onnxruntime::Graph>& sub_graph,
onnxruntime::Graph& graph, onnxruntime::Node& node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger) {
attr_proto.Clear();
LOAD_STR_FROM_ORT_FORMAT(attr_proto, name, fbs_attr.name());
@@ -283,7 +284,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
auto fbs_tensor = fbs_attr.t();
ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor attribute. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *attr_proto.mutable_t(),
can_use_flatbuffer_for_initializers));
load_options));
} break;
case AttributeProto_AttributeType_GRAPH: {
// If the attribute type is a graph, we will create an empty graph in attr_proto so that the ONNX checker
@@ -292,7 +293,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
ORT_RETURN_IF(nullptr == fbs_graph, "Null graph attribute. Invalid ORT format model.");
attr_proto.mutable_g()->set_name("Empty graph proto from deserialization of ORT format model");
ORT_RETURN_IF_ERROR(onnxruntime::Graph::LoadFromOrtFormat(*fbs_graph, graph, node,
can_use_flatbuffer_for_initializers,
load_options,
logger, sub_graph));
} break;
case AttributeProto_AttributeType_FLOATS: {
@@ -327,7 +328,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
for (const auto* fbs_tensor : *fbs_tensors) {
ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor in tensors attribute. Invalid ORT format model.");
ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *tensors->Add(),
can_use_flatbuffer_for_initializers));
load_options));
}
} break;

View file

@@ -6,6 +6,7 @@
#include <memory>
#include "core/common/status.h"
#include "core/graph/ort_format_load_options.h"
namespace ONNX_NAMESPACE {
class AttributeProto;
@@ -66,18 +67,16 @@ Status SaveAttributeOrtFormat(
/// </summary>
/// <param name="fbs_tensor">Flatbuffer Tensor</param>
/// <param name="initializer">TensorProto to load data into</param>
/// <param name="can_use_flatbuffer_for_initializers">
/// If true, set the TensorProto to point to the memory in the flatbuffer instead of copying data.
/// This requires the buffer to remain valid for the entire duration of the InferenceSession.
/// </param>
/// <param name="load_options">ORT format load options</param>
/// <returns>Status</returns>
Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor,
ONNX_NAMESPACE::TensorProto& initializer,
bool can_use_flatbuffer_for_initializers = false);
const OrtFormatLoadOptions& load_options);
#if !defined(DISABLE_SPARSE_TENSORS)
Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor,
ONNX_NAMESPACE::SparseTensorProto& initializer);
ONNX_NAMESPACE::SparseTensorProto& initializer,
const OrtFormatLoadOptions& load_options);
#endif // !defined(DISABLE_SPARSE_TENSORS)
// Load a give fbs::Attribute into AttributeProto
@@ -87,7 +86,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr,
ONNX_NAMESPACE::AttributeProto& attr_proto,
std::unique_ptr<onnxruntime::Graph>& sub_graph,
onnxruntime::Graph& graph, onnxruntime::Node& node,
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger);
} // namespace utils

View file

@@ -771,7 +771,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
#if !defined(ORT_MINIMAL_BUILD)
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
#endif
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger,
std::unique_ptr<Model>& model) {
model = std::make_unique<Model>();
@@ -838,10 +838,10 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
}
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry,
can_use_flatbuffer_for_initializers, logger, model->graph_));
load_options, logger, model->graph_));
#else
ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version,
can_use_flatbuffer_for_initializers, logger, model->graph_));
load_options, logger, model->graph_));
#endif
return Status::OK();
}

View file

@@ -9,6 +9,7 @@
#include <string>
#include "core/common/path.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/ort_format_load_options.h"
#include "core/session/onnxruntime_c_api.h"
#if !defined(ORT_MINIMAL_BUILD)
#include "core/graph/function_template.h"
@@ -297,7 +298,7 @@ class Model {
#if !defined(ORT_MINIMAL_BUILD)
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
#endif
bool can_use_flatbuffer_for_initializers,
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger,
std::unique_ptr<Model>& model);

View file

@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
/// Options to configure how an ORT format model is loaded.
struct OrtFormatLoadOptions {
/// If true, set initializer TensorProtos to point to memory in the flatbuffer instead of copying data.
/// This requires the flatbuffer to remain valid for the entire duration of the InferenceSession.
bool can_use_flatbuffer_for_initializers{true};
/// If true, do not load any saved runtime optimizations.
bool ignore_saved_runtime_optimizations{false};
};
} // namespace onnxruntime
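
A minimal caller-side sketch of how the new struct is consumed, assuming a full (non-minimal) build so the local_registries parameter is present; the signature is the one shown in the Model::LoadFromOrtFormat hunks above, and the option values here are just illustrative:

OrtFormatLoadOptions load_options{};
load_options.can_use_flatbuffer_for_initializers = false;  // copy initializer data out of the flatbuffer
load_options.ignore_saved_runtime_optimizations = true;    // drop pre-v5 runtime optimization records

std::unique_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
                                             /*local_registries*/ nullptr,
                                             load_options, logger, model));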

View file

@@ -1024,8 +1024,10 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
const auto* fbs_ort_model_version = fbs_session->ort_version();
ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model.");
auto model_version = std::stoi(fbs_ort_model_version->str());
bool is_supported = IsOrtModelVersionSupported(model_version);
const auto model_version = std::stoi(fbs_ort_model_version->str());
const bool is_supported = IsOrtModelVersionSupported(model_version);
OrtFormatLoadOptions load_options{};
#if defined(ORT_MINIMAL_BUILD)
// Note about the ORT format version 5 breaking change.
@@ -1038,16 +1040,34 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
"The ORT format model version [", fbs_ort_model_version->string_view(),
"] is not supported in this build ", ORT_VERSION, ". ",
kOrtFormatVersion5BreakingChangeNote);
#else
// models prior to v5 can be handled by inserting the kernel constraints in a full build
bool is_supported_with_update = !is_supported && model_version < 5;
#else // ^^ defined(ORT_MINIMAL_BUILD) ^^ / vv !defined(ORT_MINIMAL_BUILD) vv
const auto has_saved_runtime_optimizations = [](const fbs::InferenceSession& fbs_session) -> bool {
if (const auto* fbs_model = fbs_session.model()) {
if (const auto* fbs_graph = fbs_model->graph()) {
if (const auto* fbs_runtime_opts = fbs_graph->runtime_optimizations()) {
if (const auto* fbs_runtime_opt_records = fbs_runtime_opts->records()) {
return fbs_runtime_opt_records->size() > 0;
}
}
}
}
return false;
};
// models prior to v5 can be handled by inserting the kernel constraints in a full build
const bool is_supported_with_update = model_version < 5;
if (is_supported_with_update && has_saved_runtime_optimizations(*fbs_session)) {
LOGS(*session_logger_, WARNING)
<< "The old ORT format model (version " << fbs_ort_model_version->string_view()
<< ") has saved runtime optimizations. They will be ignored.";
load_options.ignore_saved_runtime_optimizations = true;
}
// currently this means the model is using a future version
// i.e. attempted load of model created with future version of ORT
ORT_RETURN_IF_NOT(is_supported || is_supported_with_update,
"The ORT format model version [", fbs_ort_model_version->string_view(),
"] is not supported in this build ", ORT_VERSION, ". ");
#endif
"] is not supported in this build ", ORT_VERSION, ".");
#endif // !defined(ORT_MINIMAL_BUILD)
const auto* fbs_model = fbs_session->model();
ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model.");
@@ -1058,19 +1078,18 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort
// if that is the case we also allow creating initializers that directly use those bytes.
const auto& config_options = session_options_.config_options;
using_ort_model_bytes_for_initializers_ =
ort_format_model_bytes_data_holder_.empty() &&
config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1";
load_options.can_use_flatbuffer_for_initializers =
ort_format_model_bytes_data_holder_.empty() &&
config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1";
// need to go from unique_ptr to shared_ptr when moving into model_
std::unique_ptr<Model> tmp_model;
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr,
using_ort_model_bytes_for_initializers_,
*session_logger_, tmp_model));
load_options, *session_logger_, tmp_model));
#else
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, using_ort_model_bytes_for_initializers_, *session_logger_,
tmp_model));
ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, load_options, *session_logger_, tmp_model));
#endif
ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model));
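
As a usage note (not part of this diff), load_options.can_use_flatbuffer_for_initializers is driven by a session config entry together with whether the session keeps using the caller-provided model bytes. A minimal application-side sketch follows; the literal config key strings are assumptions based on onnxruntime_session_options_config_keys.h, not something this commit establishes:

#include <memory>
#include <vector>
#include <onnxruntime_cxx_api.h>

std::unique_ptr<Ort::Session> CreateSessionUsingModelBytes(Ort::Env& env, std::vector<char>& model_bytes) {
  Ort::SessionOptions session_options;
  // Keep using the caller-provided buffer instead of copying it (assumed key name).
  session_options.AddConfigEntry("session.use_ort_model_bytes_directly", "1");
  // Let initializer TensorProtos point directly into that buffer (assumed key name).
  session_options.AddConfigEntry("session.use_ort_model_bytes_for_initializers", "1");
  // model_bytes must remain valid for the whole lifetime of the returned session.
  return std::make_unique<Ort::Session>(env, model_bytes.data(), model_bytes.size(), session_options);
}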

View file

@@ -18,11 +18,10 @@
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace std;
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::logging;
namespace onnxruntime {
namespace test {
@@ -97,7 +96,6 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val
EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i;
}
} else {
ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0);
}
}
@@ -231,8 +229,8 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object
ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion());
}
static void SaveAndCompareModels(const std::basic_string<ORTCHAR_T>& orig_file,
const std::basic_string<ORTCHAR_T>& ort_file,
static void SaveAndCompareModels(const PathString& orig_file,
const PathString& ort_file,
TransformerLevel optimization_level = TransformerLevel::Level3) {
SessionOptions so;
so.session_logid = "SerializeToOrtFormat";
@@ -301,7 +299,7 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) {
test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
OrtValue ml_value;
vector<float> data(28 * 28, 0.0);
std::vector<float> data(28 * 28, 0.0);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data,
&ml_value);
test_info.inputs.insert(std::make_pair("Input3", ml_value));
@@ -372,59 +370,103 @@ TEST(OrtModelOnlyTests, MetadataSerialization) {
}
// test we can load an old ORT format model and run it in a full build.
// we changed from using kernel hashes to kernel type constraints in v5, so any old model should be able to be loaded
// we changed from using kernel hashes to kernel type constraints in v5, so an old model should be able to be loaded
// in a full build if we add the kernel type constraints during loading. this also means we can save the updated
// ORT format model to effectively upgrade it to v5.
TEST(OrtModelOnlyTests, UpdateOrtModelVersion) {
// input is ORT format model using v4 where we used kernel hashes instead of constraints
const auto onnx_file = ORT_TSTR("testdata/mnist.onnx");
const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort");
const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort");
void TestOrtModelUpdate(const PathString& onnx_file,
const PathString& ort_file_v4,
const PathString& generated_ort_file_v5,
const std::function<void(NameMLValMap& inputs, std::vector<std::string>& output_names)>&
set_up_test_inputs_and_outputs_fn) {
// ort_file_v4 is ORT format model using v4 where we used kernel hashes instead of constraints
// update v4 model and save as v5. do not run optimizations in order to preserve the model as-is.
SaveAndCompareModels(ort_file_v4, ort_file_v5, TransformerLevel::Default);
SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default);
// run the original, v4 and v5 models and check the output is the same
RandomValueGenerator random{};
std::vector<int64_t> input_dims{1, 1, 28, 28};
std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 0.9f);
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
input_dims, input_data, &ml_value);
OrtModelTestInfo test_info;
set_up_test_inputs_and_outputs_fn(test_info.inputs, test_info.output_names);
// keep the onnx and ort models to the same optimization level
test_info.optimization_level = TransformerLevel::Level1;
test_info.inputs.insert(std::make_pair("Input3", ml_value));
test_info.output_names = {"Plus214_Output_0"};
OrtValue orig_out, v4_out, v5_out;
std::vector<OrtValue> orig_out, v4_out, v5_out;
test_info.model_filename = onnx_file;
test_info.output_verifier = [&orig_out](const std::vector<OrtValue>& fetches) {
orig_out = fetches[0];
orig_out = fetches;
};
RunOrtModel(test_info);
// run with v4 as input. this should also update to v5 prior to execution.
test_info.model_filename = ort_file_v4;
test_info.output_verifier = [&v4_out](const std::vector<OrtValue>& fetches) {
v4_out = fetches[0];
v4_out = fetches;
};
RunOrtModel(test_info);
// validate the model saved as v5 also works
test_info.model_filename = ort_file_v5;
test_info.model_filename = generated_ort_file_v5;
test_info.output_verifier = [&v5_out](const std::vector<OrtValue>& fetches) {
v5_out = fetches[0];
v5_out = fetches;
};
RunOrtModel(test_info);
CompareTensors(orig_out, v4_out);
CompareTensors(v4_out, v5_out);
auto compare_outputs = [](gsl::span<OrtValue> expected, gsl::span<OrtValue> actual) {
ASSERT_EQ(expected.size(), actual.size());
for (size_t i = 0; i < expected.size(); ++i) {
CompareTensors(expected[i], actual[i]);
}
};
compare_outputs(orig_out, v4_out);
compare_outputs(v4_out, v5_out);
};
TEST(OrtModelOnlyTests, UpdateOrtModelVersion) {
const auto onnx_file = ORT_TSTR("testdata/mnist.onnx");
const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort");
const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort");
RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure
TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5,
[&](NameMLValMap& inputs, std::vector<std::string>& output_names) {
std::vector<int64_t> input_dims{1, 1, 28, 28};
std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 0.9f);
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
input_dims, input_data, &ml_value);
inputs = {{"Input3", ml_value}};
output_names = {"Plus214_Output_0"};
});
}
// test that a model with saved runtime optimizations can also be updated
// note: the saved runtime optimizations will be ignored
TEST(OrtModelOnlyTests, UpdateOrtModelVersionWithSavedRuntimeOptimizations) {
const auto onnx_file = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.onnx");
const auto ort_file_v4 = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort");
const auto ort_file_v5 =
ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v5.test_output.ort");
RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure
TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5,
[&](NameMLValMap& inputs, std::vector<std::string>& output_names) {
constexpr int n = 3; // number of QDQ convs
for (size_t i = 0; i < n; ++i) {
std::vector<int64_t> input_dims{1, 1, 5, 5};
std::vector<uint8_t> input_data = random.Uniform<uint8_t>(input_dims, 0, 255);
OrtValue ml_value;
CreateMLValue<uint8_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
input_dims, input_data, &ml_value);
inputs.emplace(MakeString("X_", i), std::move(ml_value));
output_names.push_back(MakeString("Y_", i));
}
});
}
#if !defined(DISABLE_ML_OPS)