From 215732f74b341baa6afa2391d756ef44806cd1f2 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 8 Nov 2022 13:36:46 -0800 Subject: [PATCH] Ignore saved runtime optimizations when updating ORT format model 10536 bytes 10 files changed, 175 insertions(+), 94 deletions(-) create mode 100644 onnxruntime/core/graph/ort_format_load_options.h create mode 100644 onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index ef5cb20913..0fc2762ca0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -46,6 +46,7 @@ #endif #include "core/graph/graph_nodes.h" #include "core/graph/node_arg.h" +#include "core/graph/ort_format_load_options.h" namespace flatbuffers { class FlatBufferBuilder; @@ -487,11 +488,11 @@ class Node { #endif static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& node); Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger); Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph); @@ -1312,18 +1313,18 @@ class Graph { virtual ~Graph(); - static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model, - const std::unordered_map& domain_to_version, + static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model, + const std::unordered_map& domain_to_version, #if !defined(ORT_MINIMAL_BUILD) - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, #endif - bool can_use_flatbuffer_for_initializers, - const logging::Logger& logger, std::unique_ptr& graph); + const OrtFormatLoadOptions& load_options, + const logging::Logger& logger, std::unique_ptr& graph); // deserialize a subgraph static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Graph& parent_graph, const Node& parent_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1353,8 +1354,8 @@ class Graph { bool strict_shape_type_inference); // Populate Graph instance from ORT format serialized data. - common::Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, - bool can_use_flatbuffer_for_initializers); + Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, + const OrtFormatLoadOptions& load_options); #if !defined(ORT_MINIMAL_BUILD) // Constructor: Given a loaded from model file, construct diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index cf8302872a..08f00078e3 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -723,14 +723,14 @@ flatbuffers::Offset Node::SaveEdgesToOrtFormat(flatbuffers::FlatB #endif // !defined(ORT_MINIMAL_BUILD) Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& node) { node = std::make_unique(fbs_node.index(), graph); - return node->LoadFromOrtFormat(fbs_node, can_use_flatbuffer_for_initializers, logger); + return node->LoadFromOrtFormat(fbs_node, load_options, logger); } Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger) { auto LoadNodeArgsFromOrtFormat = [&](const flatbuffers::Vector>* fbs_node_arg_names, @@ -769,8 +769,7 @@ Status Node::LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, AttributeProto attr_proto; std::unique_ptr subgraph; ORT_RETURN_IF_ERROR( - fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, - can_use_flatbuffer_for_initializers, logger)); + fbs::utils::LoadAttributeOrtFormat(*fbs_attr, attr_proto, subgraph, *graph_, *this, load_options, logger)); // If we have a sub graph in this attributes, it will be loaded into subgraph ptr // while the attribute proto contains the sub graph will have the empty g() field @@ -4200,7 +4199,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaCollectionPtr schema_registry, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph) { graph = std::make_unique(owning_model, domain_to_version, #if !defined(ORT_MINIMAL_BUILD) @@ -4210,7 +4209,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, // Assume anything in ORT format has already been validated. false); - ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers)); + ORT_RETURN_IF_ERROR(graph->LoadFromOrtFormat(fbs_graph, load_options)); #if !defined(ORT_MINIMAL_BUILD) // in a full build we need to run Resolve to fully populate ResolveContext and Node::op_, @@ -4226,7 +4225,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, Graph& parent_graph, const Node& parent_node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph) { graph = std::make_unique(parent_graph.owning_model_, parent_graph.domain_to_version_, @@ -4238,7 +4237,7 @@ Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, // Assume anything in ORT format has already been validated. false); - return graph->LoadFromOrtFormat(fbs_graph, can_use_flatbuffer_for_initializers); + return graph->LoadFromOrtFormat(fbs_graph, load_options); } Graph::Graph(const Model& owning_model, @@ -4268,7 +4267,7 @@ Graph::Graph(const Model& owning_model, } common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, - bool can_use_flatbuffer_for_initializers) { + const OrtFormatLoadOptions& load_options) { // We deserialize the graph from ORT format in the following order: // 1. Deserialize the initializers and sparse initializers. Convert sparse to dense. // 2. Deserialize the NodeArgs @@ -4298,8 +4297,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_tensor : *fbs_initializers) { ORT_RETURN_IF(nullptr == fbs_tensor, "Initializer tensor is missing. Invalid ORT format model."); TensorProto* initializer = deserialized_proto_data_.add_initializer(); - ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, - can_use_flatbuffer_for_initializers)); + ORT_RETURN_IF_ERROR(fbs::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer, load_options)); auto p = name_to_initial_tensor_.emplace(initializer->name(), initializer); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense or ConstantNode): '" << initializer->name() @@ -4318,7 +4316,8 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_sparse_tensor : *fbs_sparse_initializers) { ORT_RETURN_IF(nullptr == fbs_sparse_tensor, "Sparse Initializer tensor is missing. Invalid ORT format model."); SparseTensorProto sparse_initializer; - ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer)); + ORT_RETURN_IF_ERROR(fbs::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer, + load_options)); TensorProto& initializer = *deserialized_proto_data_.add_initializer(); ORT_RETURN_IF_ERROR(utils::SparseTensorProtoToDenseTensorProto(sparse_initializer, model_path, initializer)); auto p = name_to_initial_tensor_.emplace(initializer.name(), &initializer); @@ -4359,8 +4358,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph for (const auto* fbs_node : *fbs_nodes) { ORT_RETURN_IF(nullptr == fbs_node, "Node is missing. Invalid ORT format model."); std::unique_ptr node; - ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, can_use_flatbuffer_for_initializers, logger_, - node)); + ORT_RETURN_IF_ERROR(Node::LoadFromOrtFormat(*fbs_node, *this, load_options, logger_, node)); ORT_RETURN_IF(node->Index() >= fbs_graph.max_node_index(), "Node index is out of range"); nodes_[node->Index()] = std::move(node); ++num_of_nodes_; @@ -4407,9 +4405,11 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph ORT_RETURN_IF_ERROR(PopulateNodeArgToProducerConsumerLookupsFromNodes()); // runtime optimizations - if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) { - if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) { - ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records)); + if (!load_options.ignore_saved_runtime_optimizations) { + if (const auto* fbs_runtime_optimizations = fbs_graph.runtime_optimizations()) { + if (const auto* fbs_runtime_optimization_records = fbs_runtime_optimizations->records()) { + ORT_RETURN_IF_ERROR(MutableRuntimeOptimizations().LoadFromOrtFormat(*fbs_runtime_optimization_records)); + } } } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 15ada5fe8c..1d1fb85bf5 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -177,7 +177,7 @@ Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder, #endif Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& initializer, - bool can_use_flatbuffer_for_initializers) { + const OrtFormatLoadOptions& load_options) { initializer.Clear(); LOAD_STR_FROM_ORT_FORMAT(initializer, name, fbs_tensor.name()); @@ -201,7 +201,7 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); ORT_RETURN_IF(nullptr == fbs_raw_data, "Missing raw data for initializer. Invalid ORT format model."); - if (can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { + if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); @@ -231,19 +231,20 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init #if !defined(DISABLE_SPARSE_TENSORS) Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, - SparseTensorProto& initializer) { + SparseTensorProto& initializer, + const OrtFormatLoadOptions& load_options) { SparseTensorProto loaded_initializer; auto fbs_values_tensor = fbs_sparse_tensor.values(); ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model."); auto* values_tensor = loaded_initializer.mutable_values(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options)); ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model."); auto fbs_indicies_tensor = fbs_sparse_tensor.indices(); ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'", "Invalid ORT format model."); auto* indicies_tensor = loaded_initializer.mutable_indices(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options)); auto fbs_dims = fbs_sparse_tensor.dims(); ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'", @@ -259,7 +260,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ONNX_NAMESPACE::AttributeProto& attr_proto, std::unique_ptr& sub_graph, onnxruntime::Graph& graph, onnxruntime::Node& node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger) { attr_proto.Clear(); LOAD_STR_FROM_ORT_FORMAT(attr_proto, name, fbs_attr.name()); @@ -283,7 +284,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, auto fbs_tensor = fbs_attr.t(); ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor attribute. Invalid ORT format model."); ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *attr_proto.mutable_t(), - can_use_flatbuffer_for_initializers)); + load_options)); } break; case AttributeProto_AttributeType_GRAPH: { // If the attribute type is a graph, we will create an empty graph in attr_proto so that the ONNX checker @@ -292,7 +293,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ORT_RETURN_IF(nullptr == fbs_graph, "Null graph attribute. Invalid ORT format model."); attr_proto.mutable_g()->set_name("Empty graph proto from deserialization of ORT format model"); ORT_RETURN_IF_ERROR(onnxruntime::Graph::LoadFromOrtFormat(*fbs_graph, graph, node, - can_use_flatbuffer_for_initializers, + load_options, logger, sub_graph)); } break; case AttributeProto_AttributeType_FLOATS: { @@ -327,7 +328,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, for (const auto* fbs_tensor : *fbs_tensors) { ORT_RETURN_IF(nullptr == fbs_tensor, "Null tensor in tensors attribute. Invalid ORT format model."); ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_tensor, *tensors->Add(), - can_use_flatbuffer_for_initializers)); + load_options)); } } break; diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index f1d7ab989c..0088babad2 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -6,6 +6,7 @@ #include #include "core/common/status.h" +#include "core/graph/ort_format_load_options.h" namespace ONNX_NAMESPACE { class AttributeProto; @@ -66,18 +67,16 @@ Status SaveAttributeOrtFormat( /// /// Flatbuffer Tensor /// TensorProto to load data into -/// -/// If true, set the TensorProto to point to the memory in the flatbuffer instead of copying data. -/// This requires the buffer to remain valid for the entire duration of the InferenceSession. -/// +/// ORT format load options /// Status Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, ONNX_NAMESPACE::TensorProto& initializer, - bool can_use_flatbuffer_for_initializers = false); + const OrtFormatLoadOptions& load_options); #if !defined(DISABLE_SPARSE_TENSORS) Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, - ONNX_NAMESPACE::SparseTensorProto& initializer); + ONNX_NAMESPACE::SparseTensorProto& initializer, + const OrtFormatLoadOptions& load_options); #endif // !defined(DISABLE_SPARSE_TENSORS) // Load a give fbs::Attribute into AttributeProto @@ -87,7 +86,7 @@ Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ONNX_NAMESPACE::AttributeProto& attr_proto, std::unique_ptr& sub_graph, onnxruntime::Graph& graph, onnxruntime::Node& node, - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger); } // namespace utils diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 4d7bce7f6d..8af9f99ed1 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -771,7 +771,7 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #if !defined(ORT_MINIMAL_BUILD) const IOnnxRuntimeOpSchemaRegistryList* local_registries, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& model) { model = std::make_unique(); @@ -838,10 +838,10 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, } ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, schema_registry, - can_use_flatbuffer_for_initializers, logger, model->graph_)); + load_options, logger, model->graph_)); #else ORT_RETURN_IF_ERROR(Graph::LoadFromOrtFormat(*fbs_graph, *model, domain_to_version, - can_use_flatbuffer_for_initializers, logger, model->graph_)); + load_options, logger, model->graph_)); #endif return Status::OK(); } diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 49280ca9fa..bd1f53b43c 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -9,6 +9,7 @@ #include #include "core/common/path.h" #include "core/graph/graph_viewer.h" +#include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function_template.h" @@ -297,7 +298,7 @@ class Model { #if !defined(ORT_MINIMAL_BUILD) const IOnnxRuntimeOpSchemaRegistryList* local_registries, #endif - bool can_use_flatbuffer_for_initializers, + const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& model); diff --git a/onnxruntime/core/graph/ort_format_load_options.h b/onnxruntime/core/graph/ort_format_load_options.h new file mode 100644 index 0000000000..c870a7ada0 --- /dev/null +++ b/onnxruntime/core/graph/ort_format_load_options.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { + +/// Options to configure how an ORT format model is loaded. +struct OrtFormatLoadOptions { + /// If true, set initializer TensorProtos to point to memory in the flatbuffer instead of copying data. + /// This requires the flatbuffer to remain valid for the entire duration of the InferenceSession. + bool can_use_flatbuffer_for_initializers{true}; + + /// If true, do not load any saved runtime optimizations. + bool ignore_saved_runtime_optimizations{false}; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 206c053a34..81d668bdce 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1024,8 +1024,10 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort const auto* fbs_ort_model_version = fbs_session->ort_version(); ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model."); - auto model_version = std::stoi(fbs_ort_model_version->str()); - bool is_supported = IsOrtModelVersionSupported(model_version); + const auto model_version = std::stoi(fbs_ort_model_version->str()); + const bool is_supported = IsOrtModelVersionSupported(model_version); + + OrtFormatLoadOptions load_options{}; #if defined(ORT_MINIMAL_BUILD) // Note about the ORT format version 5 breaking change. @@ -1038,16 +1040,34 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort "The ORT format model version [", fbs_ort_model_version->string_view(), "] is not supported in this build ", ORT_VERSION, ". ", kOrtFormatVersion5BreakingChangeNote); -#else - // models prior to v5 can be handled by inserting the kernel constraints in a full build - bool is_supported_with_update = !is_supported && model_version < 5; +#else // ^^ defined(ORT_MINIMAL_BUILD) ^^ / vv !defined(ORT_MINIMAL_BUILD) vv + const auto has_saved_runtime_optimizations = [](const fbs::InferenceSession& fbs_session) -> bool { + if (const auto* fbs_model = fbs_session.model()) { + if (const auto* fbs_graph = fbs_model->graph()) { + if (const auto* fbs_runtime_opts = fbs_graph->runtime_optimizations()) { + if (const auto* fbs_runtime_opt_records = fbs_runtime_opts->records()) { + return fbs_runtime_opt_records->size() > 0; + } + } + } + } + return false; + }; + + // models prior to v5 can be handled by inserting the kernel constraints in a full build + const bool is_supported_with_update = model_version < 5; + + if (is_supported_with_update && has_saved_runtime_optimizations(*fbs_session)) { + LOGS(*session_logger_, WARNING) + << "The old ORT format model (version " << fbs_ort_model_version->string_view() + << ") has saved runtime optimizations. They will be ignored."; + load_options.ignore_saved_runtime_optimizations = true; + } - // currently this means the model is using a future version - // i.e. attempted load of model created with future version of ORT ORT_RETURN_IF_NOT(is_supported || is_supported_with_update, "The ORT format model version [", fbs_ort_model_version->string_view(), - "] is not supported in this build ", ORT_VERSION, ". "); -#endif + "] is not supported in this build ", ORT_VERSION, "."); +#endif // !defined(ORT_MINIMAL_BUILD) const auto* fbs_model = fbs_session->model(); ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model."); @@ -1058,19 +1078,18 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort // if that is the case we also allow creating initializers that directly use those bytes. const auto& config_options = session_options_.config_options; using_ort_model_bytes_for_initializers_ = - ort_format_model_bytes_data_holder_.empty() && - config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1"; + load_options.can_use_flatbuffer_for_initializers = + ort_format_model_bytes_data_holder_.empty() && + config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesForInitializers, "0") == "1"; // need to go from unique_ptr to shared_ptr when moving into model_ std::unique_ptr tmp_model; #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, - using_ort_model_bytes_for_initializers_, - *session_logger_, tmp_model)); + load_options, *session_logger_, tmp_model)); #else - ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, using_ort_model_bytes_for_initializers_, *session_logger_, - tmp_model)); + ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, load_options, *session_logger_, tmp_model)); #endif ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model)); diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 595f9fb821..5d4901eee2 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -18,11 +18,10 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" -using namespace std; using namespace ONNX_NAMESPACE; -using namespace onnxruntime::logging; namespace onnxruntime { namespace test { @@ -97,7 +96,6 @@ static void CompareTensors(const OrtValue& left_value, const OrtValue& right_val EXPECT_EQ(left_strings[i], right_strings[i]) << "Mismatch index:" << i; } } else { - ASSERT_EQ(memcmp(left.DataRaw(), right.DataRaw(), left.SizeInBytes()), 0); } } @@ -231,8 +229,8 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object ASSERT_EQ(model_1.ProducerVersion(), model_2.ProducerVersion()); } -static void SaveAndCompareModels(const std::basic_string& orig_file, - const std::basic_string& ort_file, +static void SaveAndCompareModels(const PathString& orig_file, + const PathString& ort_file, TransformerLevel optimization_level = TransformerLevel::Level3) { SessionOptions so; so.session_logid = "SerializeToOrtFormat"; @@ -301,7 +299,7 @@ TEST(OrtModelOnlyTests, ValidateOrtFormatModelDoesNotRunOptimizersInFullBuild) { test_info.configs.push_back(std::make_pair(kOrtSessionOptionsConfigLoadModelFormat, "ORT")); OrtValue ml_value; - vector data(28 * 28, 0.0); + std::vector data(28 * 28, 0.0); CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1, 1, 28, 28}, data, &ml_value); test_info.inputs.insert(std::make_pair("Input3", ml_value)); @@ -372,59 +370,103 @@ TEST(OrtModelOnlyTests, MetadataSerialization) { } // test we can load an old ORT format model and run it in a full build. -// we changed from using kernel hashes to kernel type constraints in v5, so any old model should be able to be loaded +// we changed from using kernel hashes to kernel type constraints in v5, so an old model should be able to be loaded // in a full build if we add the kernel type constraints during loading. this also means we can save the updated // ORT format model to effectively upgrade it to v5. -TEST(OrtModelOnlyTests, UpdateOrtModelVersion) { - // input is ORT format model using v4 where we used kernel hashes instead of constraints - const auto onnx_file = ORT_TSTR("testdata/mnist.onnx"); - const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort"); - const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort"); +void TestOrtModelUpdate(const PathString& onnx_file, + const PathString& ort_file_v4, + const PathString& generated_ort_file_v5, + const std::function& output_names)>& + set_up_test_inputs_and_outputs_fn) { + // ort_file_v4 is ORT format model using v4 where we used kernel hashes instead of constraints // update v4 model and save as v5. do not run optimizations in order to preserve the model as-is. - SaveAndCompareModels(ort_file_v4, ort_file_v5, TransformerLevel::Default); + SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default); // run the original, v4 and v5 models and check the output is the same - RandomValueGenerator random{}; - std::vector input_dims{1, 1, 28, 28}; - std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.9f); - - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), - input_dims, input_data, &ml_value); - OrtModelTestInfo test_info; + set_up_test_inputs_and_outputs_fn(test_info.inputs, test_info.output_names); // keep the onnx and ort models to the same optimization level test_info.optimization_level = TransformerLevel::Level1; - test_info.inputs.insert(std::make_pair("Input3", ml_value)); - test_info.output_names = {"Plus214_Output_0"}; - - OrtValue orig_out, v4_out, v5_out; + std::vector orig_out, v4_out, v5_out; test_info.model_filename = onnx_file; test_info.output_verifier = [&orig_out](const std::vector& fetches) { - orig_out = fetches[0]; + orig_out = fetches; }; RunOrtModel(test_info); // run with v4 as input. this should also update to v5 prior to execution. test_info.model_filename = ort_file_v4; test_info.output_verifier = [&v4_out](const std::vector& fetches) { - v4_out = fetches[0]; + v4_out = fetches; }; RunOrtModel(test_info); // validate the model saved as v5 also works - test_info.model_filename = ort_file_v5; + test_info.model_filename = generated_ort_file_v5; test_info.output_verifier = [&v5_out](const std::vector& fetches) { - v5_out = fetches[0]; + v5_out = fetches; }; RunOrtModel(test_info); - CompareTensors(orig_out, v4_out); - CompareTensors(v4_out, v5_out); + auto compare_outputs = [](gsl::span expected, gsl::span actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + CompareTensors(expected[i], actual[i]); + } + }; + + compare_outputs(orig_out, v4_out); + compare_outputs(v4_out, v5_out); +}; + +TEST(OrtModelOnlyTests, UpdateOrtModelVersion) { + const auto onnx_file = ORT_TSTR("testdata/mnist.onnx"); + const auto ort_file_v4 = ORT_TSTR("testdata/mnist.basic.v4.ort"); + const auto ort_file_v5 = ORT_TSTR("testdata/mnist.basic.v5.test_output.ort"); + + RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure + + TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5, + [&](NameMLValMap& inputs, std::vector& output_names) { + std::vector input_dims{1, 1, 28, 28}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.9f); + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + input_dims, input_data, &ml_value); + + inputs = {{"Input3", ml_value}}; + output_names = {"Plus214_Output_0"}; + }); +} + +// test that a model with saved runtime optimizations can also be updated +// note: the saved runtime optimizations will be ignored +TEST(OrtModelOnlyTests, UpdateOrtModelVersionWithSavedRuntimeOptimizations) { + const auto onnx_file = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.onnx"); + const auto ort_file_v4 = ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort"); + const auto ort_file_v5 = + ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v5.test_output.ort"); + + RandomValueGenerator random{}; // keep in scope so we get random seed trace message on failure + + TestOrtModelUpdate(onnx_file, ort_file_v4, ort_file_v5, + [&](NameMLValMap& inputs, std::vector& output_names) { + constexpr int n = 3; // number of QDQ convs + for (size_t i = 0; i < n; ++i) { + std::vector input_dims{1, 1, 5, 5}; + std::vector input_data = random.Uniform(input_dims, 0, 255); + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + input_dims, input_data, &ml_value); + + inputs.emplace(MakeString("X_", i), std::move(ml_value)); + output_names.push_back(MakeString("Y_", i)); + } + }); } #if !defined(DISABLE_ML_OPS) diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.v4.ort new file mode 100644 index 0000000000000000000000000000000000000000..2fc70dedb3c529ebff9afbc1badabeedee6e313a GIT binary patch literal 10536 zcmb`Ne`uB08OPtWS(>%ZWiHDy=d#pgE`OA8%_KMxacwGLxKwRsT@g)|+;}gGH(9*e zAd-b4PKF!ewi1en6e*#I6W0aVN-3ojDI(I95{i>h#7PmCwOy8VE_a{reV^y%_08+l zc1Ir1bDs13KF>Mt`<|QHAPAn=)bV{f+Jiw=upp=m%E1D#AeaGeHiG}9P|&y>P+P+I zAic()2PRkD-2aF6nYG?u+_CoOKmO~HU!UD|@`2o2zuk5EJAeA~>R*1C^?C0oxizI1oKfAG8A`D}l#W3Vf?tG9n}XL^w7pzgy$VT!=&!J~Rl zQAb_FbRkD#=he|0|zWfa^V$pf@T1 zZ@|}v-bd!?tuErfT1;37?^}u?HV1Q!Sf>VbH=yeh> zMZgWf0)lTv*5$7tdkq6`BYzRuav+6#2lAcB=aC;n{%vF{fokL%k$?W9LP7PUJSk^t z%B$+k+q3X&VIYu9C(r}jL}&^-k{AETH>n}|PCz=O@5a9h1>tI#Ez(5q8p`exO%1n- z{Ni&s`;Cyk0Odnpr;fiQ&af$u{k!u6yYt&OZf_2P-hUJdvhTGh9rDSW^h?GYsVX+{ zDQ-3SHjJKw^a90s82AlP4O1I=l?};`HUz;Tdd*8tdrmQB{yA5tsxwj_v_H}D2@k}q<$e2x01_og)~`v0SB*RJiY z#kGBuoGRx@acKS2THuj?0lQ-$np7f^urKWn(^I1*4 zjicuvy`T}$xuA375=>PnkPX#fXMOUyah$!Ry(W8iIyctvZ1j5dyxh&Xafk)1`=sR!h9HK3EZ zx)`M*E-#+b@p$fLo$D>0n>+x|^Rlo(iL*BE93ie>kYmM}6qnl{h9_y?RU`NaAbG_# zTE|+__W+WWzd3T^dKli*IkKO`Xl*X){2-QVdZ=4dSg~`2XoBGBn}vebJt-c|tF^O_kFb3>H!D)Kw(fOj6 z?vr$)Nv;*po~z=Y@>>ZgTZB;@OHR1z`gZRr>xjYqf~`~ss+~K%t4wkb$cEk#CrI21 zdd2@3?gx{w6@>o!OX3K6;Tyy@^@h+L-JHx-y#IMsb{J!Pb zDZb+<_C`Q)Wq>>i4gu1mxM#jHKkkK*=LRgLb-bInm!MQ@uz(_(1kcNeHPuh-d?NA( zi}j_*wI`Bdc6q)6BN2z+jq#Wjqo4s80JO(PfKfoMR1)tA`Y}K`P|a(nhTq5X@Yd&y zW>q7KU2vy$Pf?(Xf9sctpB&@?*?4Tj4;~*%ckkAx?S{erJ=vb{yH*xk(?7Y*DG|QE zmENVu*2C~O@Sc)UJpn6GaELSW8bvh%yh@POzyNlufD>%?DO#KIs9$_M%b+93A6_|Ngy)>e;Ak`Sq;F;h$|0E{}#RvDX-gZgg9tf2e|Soo>k!L z1zUjCKs})GW}DkU&#+86Q_C4W6Qdt6`iRkGL+zN+M`LuYb+Oj$kZorCj``dMzQ}OU zvPy9G`G)E9@vM!HPlbNEUF3@Ad(lSZ8)OxHq_tRYdm=XaA~rUG-!3mXo6FG4H=Uho z2aG-vqaQQ+sL|cOaigD((Y41)^+ZzMkNCUHa9GNgfO`(^1HYZWbv9Q^FF9Yw7Ne(P z^me1S8QuNuHF{@^9$!aZ1A=3J4|A+_d>4M@n^txa+8Xn4v&rasW)XV*K<` zdXv!`jqd*GUQ+yPWAt0r5r6j*m-29hbA23M&&f$}e-0fnT}l3)vbiyO`Q_`Vdr9%X z8lzt}`h?Nl-)W;y#psK<_mWTR@o%&FDm&}J-M?mV*-G-S-R9cpCFlP28oe_{A2#}s z(cM4YGfMyA82yjt-%C6brzV?!}^d!HakLPp1dDhwRVh^4o!LH_J{PywWW@Pn%89vJK##;}&rF z?RM7L_}ZQA?8eYa**Rc#l&9AscFux(JjY^oPT6>jUgkaTr95#~UXJ*`*l_VyfxBMa zjYa*hU{2#P?x~`B!ISfZoptakPxAj~5j#WRo~Lc5*YlLKvF?WvEBU{OL1Zn$9AV={ zvvHL9tH}78Ujla-O?7j`{1)#)@KXGy$H&7quIJ7~yivWpwgyjjclD*WFAvU`Zsp-dME8}5 z?l(P}R z!vm(*^YW~XS913v-e_JDYlUaaJNV>w78x$tN^sY!-(VD%k7sSHd&jt)x%{VsX6B{4 zg9n=Kw}jKmdcZw~?}1BqQvQ2wu9IGJ{+t~)`cRDitkDk}-Qz!H^syNIE%Q(B0SgV6 zopNyZ?|qX^@=w3Ph#t^O&g-Mj=(RC=i_ud?cmLXr-WH=v-&54_UhbfjJJscRJGJbnxOmQ0zONlZf9|;FZ6v zz@WteT9yI#b5-|@^21}<3n6I+a2k!oDGT9{mayFNtmz>wx z0i%z^=*Ns6U#D@SpEiE?JFK(j;63y61V-u%m%VDc^OA2Ln~o$ubvGt@6}{x#&nBZc z#^~#ezSijO&o;(fQ*irJ&2ItW@9xs)^WaJK#@hkEEpSie_2%(FE}KB7*V`**SNi1d z1sk8E7q8nXy&q6x;hmH;hkqv`>=1>r&g?NPTMh2(xfEQwlk&92=2p^6&g*Tn(Kp2C zIiqKc?*1KM9OkL!wt14@3(2ATt$U{o;CE~fSv-0U(aLn!=zr_r`Qf^wv_EA}H@^CN z8RaLmUzKQo?pe@W{5yk-6ht4808_fd|Jr+eRQxUuJ zc_r}4>t2uZx!+#3|Mia6gg2I)imZO9^SuqHto!I5BBfs@69)*hX&Gv*`4WJ zzN_z6ewPXFts*i^9E|BzzPqwb{khyw)4*VQPd5EbQ+}{3o6F{QNYB|Azv6fPDBmP} zsv9-=tGglR+lhbr+qCe+o;yRtr=NmvjVsrkd8QkF{T`$f-hT6>SCfo4