diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index ef5cb20913..0fc2762ca0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -46,6 +46,7 @@ #endif #include "core/graph/graph_nodes.h" #include "core/graph/node_arg.h" +#include "core/graph/ort_format_load_options.h" namespace flatbuffers { class FlatBufferBuilder; @@ -487,11 +488,11 @@ class Node { #endif static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& node); Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger); Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph); @@ -1312,18 +1313,18 @@ class Graph { virtual ~Graph(); - static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model, - const std::unordered_map& domain_to_version, + static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model, + const std::unordered_map& domain_to_version, #if !defined(ORT_MINIMAL_BUILD) - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, #endif - bool can_use_flatbuffer_for_initializers, - const logging::Logger& logger, std::unique_ptr& graph); + const OrtFormatLoadOptions& load_options, + const logging::Logger& logger, std::unique_ptr& graph); // deserialize a subgraph static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Graph& parent_graph, const Node& parent_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1353,8 +1354,8 @@ class Graph { bool strict_shape_type_inference); // Populate Graph instance from ORT format serialized data. - common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, - bool can_use_flatbuffer_for_initializers); + Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, + const OrtFormatLoadOptions& load_options); #if !defined(ORT_MINIMAL_BUILD) // Constructor: Given a loaded from model file, construct diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index cf8302872a..08f00078e3 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -723,14 +723,14 @@ flatbuffers::Offset Node::SaveEdgesToOrtFormat(flatbuffers::FlatB #endif // !defined(ORT_MINIMAL_BUILD) Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& node) { node = std::make_unique(fbs_node.index(), graph); - return node->LoadFromOrtFormat(fbs_node, can_use_flatbuffer_for_initializers, logger); + return node->LoadFromOrtFormat(fbs_node, load_options, logger); } Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger) { auto LoadNodeArgsFromOrtFormat = [&](const flatbuffers::Vector>* fbs_node_arg_names, @@ -769,8 +769,7 @@ Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, AttributeProto attr_proto; std::unique_ptr subgraph; ORT_RETURN_IF_ERROR( - fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, - can_use_flatbuffer_for_initializers, logger)); + fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, load_options, logger)); // If we have a sub graph in this attributes, it will be loaded into subgraph ptr // while the attribute proto contains the sub graph will have the empty g() field @@ -4200,7 +4199,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaCollectionPtr schema_registry, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph) { graph = std::make_unique(owning_model, domain_to_version, #if !defined(ORT_MINIMAL_BUILD) @@ -4210,7 +4209,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, // Assume anything in ORT format has already been validated. false); - ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers)); + ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, load_options)); #if !defined(ORT_MINIMAL_BUILD) // in a full build we need to run Resolve to fully populate ResolveContext and Node::op_, @@ -4226,7 +4225,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Graph& parent_graph, const Node& parent_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph) { graph = std::make_unique(parent_graph.owning_model_, parent_graph.domain_to_version_, @@ -4238,7 +4237,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, // Assume anything in ORT format has already been validated. false); - return graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers); + return graph->LoadFromOrtFormat(fbs_graph, load_options); } Graph::Graph(const Model& owning_model, @@ -4268,7 +4267,7 @@ Graph::Graph(const Model& owning_model, } common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, - bool can_use_flatbuffer_for_initializers) { + const OrtFormatLoadOptions& load_options) { // We deserialize the graph from ORT format in the following order: // 1. Deserialize the initializers and sparse initializers. Convert sparse to dense. // 2. Deserialize the NodeArgs @@ -4298,8 +4297,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_tensor : *fbs_initializers) { ORT_RETURN_IF(nullptr == fbs_tensor, "Initializer tensor is missing. Invalid ORT format model."); TensorProto* initializer = deserialized_proto_data_.add_initializer(); - ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, - can_use_flatbuffer_for_initializers)); + ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, load_options)); auto p = name_to_initial_tensor_.emplace(initializer->name(), initializer); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense or ConstantNode): '" << initializer->name() @@ -4318,7 +4316,8 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_sparse_tensor : *fbs_sparse_initializers) { ORT_RETURN_IF(nullptr == fbs_sparse_tensor, "Sparse Initializer tensor is missing. Invalid ORT format model."); SparseTensorProto sparse_initializer; - ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer)); + ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer, + load_options)); TensorProto& initializer = *deserialized_proto_data_.add_initializer(); ORT_RETURN_IF_ERROR(utils::SparseTensorProtoToDenseTensorProto(sparse_initializer, model_path, initializer)); auto p = name_to_initial_tensor_.emplace(initializer.name(), &initializer); @@ -4359,8 +4358,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_node : *fbs_nodes) { ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model."); std::unique_ptr node; - ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, can_use_flatbuffer_for_initializers, logger_, - node)); + ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, load_options, logger_, node)); ORT_RETURN_IF(node->Index() >= fbs_graph.max_node_index(), "Node index is out of range"); nodes_[node->Index()] = std::move(node); ++num_of_nodes_; @@ -4407,9 +4405,11 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph ORT_RETURN_IF_ERROR(PopulateNodeArgToProducerConsumerLookupsFromNodes()); // runtime optimizations - if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) { - if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) { - ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records)); + if (!load_options.ignore_saved_runtime_optimizations) { + if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) { + if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) { + ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records)); + } } } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 15ada5fe8c..1d1fb85bf5 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -177,7 +177,7 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder, #endif Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& initializer, - bool can_use_flatbuffer_for_initializers) { + const OrtFormatLoadOptions& load_options) { initializer.Clear(); LOAD_STR_FROM_ORT_FORMAT(initializer, name, fbs_tensor.name()); @@ -201,7 +201,7 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); ORT_RETURN_IF(nullptr == fbs_raw_data, "Missing raw data for initializer. Invalid ORT format model."); - if (can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { + if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); @@ -231,19 +231,20 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init #if !defined(DISABLE_SPARSE_TENSORS) Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, - SparseTensorProto& initializer) { + SparseTensorProto& initializer, + const OrtFormatLoadOptions& load_options) { SparseTensorProto loaded_initializer; auto fbs_values_tensor = fbs_sparse_tensor.values(); ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model."); auto* values_tensor = loaded_initializer.mutable_values(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options)); ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model."); auto fbs_indicies_tensor = fbs_sparse_tensor.indices(); ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'", "Invalid ORT format model."); auto* indicies_tensor = loaded_initializer.mutable_indices(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options)); auto fbs_dims = fbs_sparse_tensor.dims(); ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'", @@ -259,7 +260,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ONNX_NAMESPACE::AttributeProto& attr_proto, std::unique_ptr& sub_graph, onnxruntime::Graph& graph, onnxruntime::Node& node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger) { attr_proto.Clear(); LOAD_STR_FROM_ORT_FORMAT(attr_proto, name, fbs_attr.name()); @@ -283,7 +284,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, auto fbs_tensor = fbs_attr.t(); ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor attribute. Invalid ORT format model."); ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *attr_proto.mutable_t(), - can_use_flatbuffer_for_initializers)); + load_options)); } break; case AttributeProto_AttributeType_GRAPH: { // If the attribute type is a graph, we will create an empty graph in attr_proto so that the ONNX checker @@ -292,7 +293,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ORT_RETURN_IF(nullptr == fbs_graph, "Null graph attribute. Invalid ORT format model."); attr_proto.mutable_g()->set_name("Empty graph proto from deserialization of ORT format model"); ORT_RETURN_IF_ERROR(onnxruntime::Graph::LoadFromOrtFormat(*fbs_graph, graph, node, - can_use_flatbuffer_for_initializers, + load_options, logger, sub_graph)); } break; case AttributeProto_AttributeType_FLOATS: { @@ -327,7 +328,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, for (const auto* fbs_tensor : *fbs_tensors) { ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor in tensors attribute. Invalid ORT format model."); ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *tensors->Add(), - can_use_flatbuffer_for_initializers)); + load_options)); } } break; diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index f1d7ab989c..0088babad2 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -6,6 +6,7 @@ #include #include "core/common/status.h" +#include "core/graph/ort_format_load_options.h" namespace ONNX_NAMESPACE { class AttributeProto; @@ -66,18 +67,16 @@ Status SaveAttributeOrtFormat( /// /// Flatbuffer Tensor /// TensorProto to load data into -/// -/// If true, set the TensorProto to point to the memory in the flatbuffer instead of copying data. -/// This requires the buffer to remain valid for the entire duration of the InferenceSession. -/// +/// ORT format load options /// Status Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, ONNX_NAMESPACE::TensorProto& initializer, - bool can_use_flatbuffer_for_initializers = false); + const OrtFormatLoadOptions& load_options); #if !defined(DISABLE_SPARSE_TENSORS) Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, - ONNX_NAMESPACE::SparseTensorProto& initializer); + ONNX_NAMESPACE::SparseTensorProto& initializer, + const OrtFormatLoadOptions& load_options); #endif // !defined(DISABLE_SPARSE_TENSORS) // Load a give fbs::Attribute into AttributeProto @@ -87,7 +86,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ONNX_NAMESPACE::AttributeProto& attr_proto, std::unique_ptr& sub_graph, onnxruntime::Graph& graph, onnxruntime::Node& node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger); } // namespace utils diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 4d7bce7f6d..8af9f99ed1 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -771,7 +771,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #if !defined(ORT_MINIMAL_BUILD) const IOnnxRuntimeOpSchemaRegistryList* local_registries, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& model) { model = std::make_unique(); @@ -838,10 +838,10 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, } ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, - can_use_flatbuffer_for_initializers, logger, model->graph_)); + load_options, logger, model->graph_)); #else ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, - can_use_flatbuffer_for_initializers, logger, model->graph_)); + load_options, logger, model->graph_)); #endif return Status::OK(); } diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 49280ca9fa..bd1f53b43c 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -9,6 +9,7 @@ #include #include "core/common/path.h" #include "core/graph/graph_viewer.h" +#include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function_template.h" @@ -297,7 +298,7 @@ class Model { #if !defined(ORT_MINIMAL_BUILD) const IOnnxRuntimeOpSchemaRegistryList* local_registries, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& model); diff --git a/onnxruntime/core/graph/ort_format_load_options.h b/onnxruntime/core/graph/ort_format_load_options.h new file mode 100644 index 0000000000..c870a7ada0 --- /dev/null +++ b/onnxruntime/core/graph/ort_format_load_options.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { + +/// Options to configure how an ORT format model is loaded. +struct OrtFormatLoadOptions { + /// If true, set initializer TensorProtos to point to memory in the flatbuffer instead of copying data. + /// This requires the flatbuffer to remain valid for the entire duration of the InferenceSession. + bool can_use_flatbuffer_for_initializers{true}; + + /// If true, do not load any saved runtime optimizations. + bool ignore_saved_runtime_optimizations{false}; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 206c053a34..81d668bdce 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1024,8 +1024,10 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort const auto* fbs_ort_model_version = fbs_session->ort_version(); ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model."); - auto model_version = std::stoi(fbs_ort_model_version->str()); - bool is_supported = IsOrtModelVersionSupported(model_version); + const auto model_version = std::stoi(fbs_ort_model_version->str()); + const bool is_supported = IsOrtModelVersionSupported(model_version); + + OrtFormatLoadOptions load_options{}; #if defined(ORT_MINIMAL_BUILD) // Note about the ORT format version 5 breaking change. @@ -1038,16 +1040,34 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort "The ORT format model version [", fbs_ort_model_version->string_view(), "] is not supported in this build ", ORT_VERSION, ". ", kOrtFormatVersion5BreakingChangeNote); -#else - // models prior to v5 can be handled by inserting the kernel constraints in a full build - bool is_supported_with_update = !is_supported && model_version < 5; +#else // ^^ defined(ORT_MINIMAL_BUILD) ^^ / vv !defined(ORT_MINIMAL_BUILD) vv + const auto has_saved_runtime_optimizations = [](const fbs::InferenceSession& fbs_session) -> bool { + if (const auto* fbs_model = fbs_session.model()) { + if (const auto* fbs_graph = fbs_model->graph()) { + if (const auto* fbs_runtime_opts = fbs_graph->runtime_optimizations()) { + if (const auto* fbs_runtime_opt_records = fbs_runtime_opts->records()) { + return fbs_runtime_opt_records->size() > 0; + } + } + } + } + return false; + }; + + // models prior to v5 can be handled by inserting the kernel constraints in a full build + const bool is_supported_with_update = model_version < 5; + + if (is_supported_with_update && has_saved_runtime_optimizations(*fbs_session)) { + LOGS(*session_logger_, WARNING) + << "The old ORT format model (version " << fbs_ort_model_version->string_view() + << ") has saved runtime optimizations. They will be ignored."; + load_options.ignore_saved_runtime_optimizations = true; + } - // currently this means the model is using a future version - // i.e. attempted load of model created with future version of ORT ORT_RETURN_IF_NOT(is_supported || is_supported_with_update, "The ORT format model version [", fbs_ort_model_version->string_view(), - "] is not supported in this build ", ORT_VERSION, ". "); -#endif + "] is not supported in this build ", ORT_VERSION, "."); +#endif // !defined(ORT_MINIMAL_BUILD) const auto* fbs_model = fbs_session->model(); ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model."); @@ -1058,19 +1078,18 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort // if that is the case we also allow creating initializers that directly use those bytes. const auto& config_options = session_options_.config_options; using_ort_model_bytes_for_initializers_ = - ort_format_model_bytes_data_holder_.empty() && - config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1"; + load_options.can_use_flatbuffer_for_initializers = + ort_format_model_bytes_data_holder_.empty() && + config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1"; // need to go from unique_ptr to shared_ptr when moving into model_ std::unique_ptr tmp_model; #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, - using_ort_model_bytes_for_initializers_, - *session_logger_, tmp_model)); + load_options, *session_logger_, tmp_model)); #else - ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, using_ort_model_bytes_for_initializers_, *session_logger_, - tmp_model)); + ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, load_options, *session_logger_, tmp_model)); #endif ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model)); diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 595f9fb821..5d4901eee2 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -18,11 +18,10 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" -using namespace std; using namespace ONNX_NAMESPACE; -using namespace onnxruntime::logging; namespace onnxruntime { namespace test { @@ -97,7 +96,6 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i; } } else { - ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0); } } @@ -231,8 +229,8 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion()); } -static void SaveAndCompareModels(const std::basic_string& orig_file, - const std::basic_string& ort_file, +static void SaveAndCompareModels(const PathString& orig_file, + const PathString& ort_file, TransformerLevel optimization_level = TransformerLevel::Level3) { SessionOptions so; so.session_logid = "SerializeToOrtFormat"; @@ -301,7 +299,7 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT")); OrtValue ml_value; - vector data(28 * 28, 0.0); + std::vector data(28 * 28, 0.0); CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data, &ml_value); test_info.inputs.insert(std::make_pair("Input3", ml_value)); @@ -372,59 +370,103 @@ TEST(OrtModelOnlyTests, MetadataSerialization) { } // test we can load an old ORT format model and run it in a full build. -// we changed from using kernel hashes to kernel type constraints in v5, so any old model should be able to be loaded +// we changed from using kernel hashes to kernel type constraints in v5, so an old model should be able to be loaded // in a full build if we add the kernel type constraints during loading. this also means we can save the updated // ORT format model to effectively upgrade it to v5. -TEST(OrtModelOnlyTests, UpdateOrtModelVersion) { - // input is ORT format model using v4 where we used kernel hashes instead of constraints - const auto onnx_file = ORT_TSTR("testdata/mnist.onnx"); - const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort"); - const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort"); +void TestOrtModelUpdate(const PathString& onnx_file, + const PathString& ort_file_v4, + const PathString& generated_ort_file_v5, + const std::function& output_names)>& + set_up_test_inputs_and_outputs_fn) { + // ort_file_v4 is ORT format model using v4 where we used kernel hashes instead of constraints // update v4 model and save as v5. do not run optimizations in order to preserve the model as-is. - SaveAndCompareModels(ort_file_v4, ort_file_v5, TransformerLevel::Default); + SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default); // run the original, v4 and v5 models and check the output is the same - RandomValueGenerator random{}; - std::vector input_dims{1, 1, 28, 28}; - std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.9f); - - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), - input_dims, input_data, &ml_value); - OrtModelTestInfo test_info; + set_up_test_inputs_and_outputs_fn(test_info.inputs, test_info.output_names); // keep the onnx and ort models to the same optimization level test_info.optimization_level = TransformerLevel::Level1; - test_info.inputs.insert(std::make_pair("Input3", ml_value)); - test_info.output_names = {"Plus214_Output_0"}; - - OrtValue orig_out, v4_out, v5_out; + std::vector orig_out, v4_out, v5_out; test_info.model_filename = onnx_file; test_info.output_verifier = [&orig_out](const std::vector& fetches) { - orig_out = fetches[0]; + orig_out = fetches; }; RunOrtModel(test_info); // run with v4 as input. this should also update to v5 prior to execution. test_info.model_filename = ort_file_v4; test_info.output_verifier = [&v4_out](const std::vector& fetches) { - v4_out = fetches[0]; + v4_out = fetches; }; RunOrtModel(test_info); // validate the model saved as v5 also works - test_info.model_filename = ort_file_v5; + test_info.model_filename = generated_ort_file_v5; test_info.output_verifier = [&v5_out](const std::vector& fetches) { - v5_out = fetches[0]; + v5_out = fetches; }; RunOrtModel(test_info); - CompareTensors(orig_out, v4_out); - CompareTensors(v4_out, v5_out); + auto compare_outputs = [](gsl::span expected, gsl::span actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + CompareTensors(expected[i], actual[i]); + } + }; + + compare_outputs(orig_out, v4_out); + compare_outputs(v4_out, v5_out); +}; + +TEST(OrtModelOnlyTests, UpdateOrtModelVersion) { + const auto onnx_file = ORT_TSTR("testdata/mnist.onnx"); + const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort"); + const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort"); + + RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure + + TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5, + [&](NameMLValMap& inputs, std::vector& output_names) { + std::vector input_dims{1, 1, 28, 28}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.9f); + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + input_dims, input_data, &ml_value); + + inputs = {{"Input3", ml_value}}; + output_names = {"Plus214_Output_0"}; + }); +} + +// test that a model with saved runtime optimizations can also be updated +// note: the saved runtime optimizations will be ignored +TEST(OrtModelOnlyTests, UpdateOrtModelVersionWithSavedRuntimeOptimizations) { + const auto onnx_file = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.onnx"); + const auto ort_file_v4 = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort"); + const auto ort_file_v5 = + ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v5.test_output.ort"); + + RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure + + TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5, + [&](NameMLValMap& inputs, std::vector& output_names) { + constexpr int n = 3; // number of QDQ convs + for (size_t i = 0; i < n; ++i) { + std::vector input_dims{1, 1, 5, 5}; + std::vector input_data = random.Uniform(input_dims, 0, 255); + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + input_dims, input_data, &ml_value); + + inputs.emplace(MakeString("X_", i), std::move(ml_value)); + output_names.push_back(MakeString("Y_", i)); + } + }); } #if !defined(DISABLE_ML_OPS) diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort new file mode 100644 index 0000000000..2fc70dedb3 Binary files /dev/null and b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort differ