From 385fab5baec0ba0e936a7232cc5657bf85866af8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 6 Oct 2023 15:56:33 -0700 Subject: [PATCH] [QNN EP] Qnn cache improvement (#17757) ### Description Improve the QNN context binary cache feature to reduce the memory overhead and initialization time overhead. Instead of dumping a Qnn context binary file with metadata as header, we dump a Onnx format file with metadata inside Onnx node. ### Motivation and Context reduce the memory overhead and initialization time overhead --- docs/ContribOperators.md | 50 ++++ .../core/session/onnxruntime_c_api.h | 3 + .../core/graph/contrib_ops/contrib_defs.cc | 77 +++++ .../qnn/builder/onnx_ctx_model_helper.cc | 264 ++++++++++++++++++ .../qnn/builder/onnx_ctx_model_helper.h | 108 +++++++ .../qnn/builder/qnn_backend_manager.cc | 243 ++-------------- .../qnn/builder/qnn_backend_manager.h | 44 ++- .../core/providers/qnn/builder/qnn_model.cc | 6 +- .../core/providers/qnn/builder/qnn_model.h | 23 +- .../providers/qnn/qnn_execution_provider.cc | 84 ++++-- .../providers/qnn/qnn_execution_provider.h | 2 + onnxruntime/test/onnx/main.cc | 6 + .../test/providers/qnn/qnn_test_utils.h | 20 +- .../test/providers/qnn/simple_op_htp_test.cc | 121 +++++++- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 4 +- 15 files changed, 770 insertions(+), 285 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc create mode 100644 onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 95dc8c3cde..888bcdbb9e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -27,6 +27,7 @@ Do not modify directly.* * com.microsoft.DequantizeWithOrder * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul + * com.microsoft.EPContext * com.microsoft.EmbedLayerNormalization * com.microsoft.ExpandDims * com.microsoft.FastGelu @@ -1520,6 +1521,55 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.EPContext** + + Onnx node container for EP context. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
embed_mode : int
+
1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content.The path is relative to this Onnx file. Default is 1.
+
ep_cache_context : string
+
payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.
+
ep_sdk_version : string
+
(Optional) SDK version used to convert the model.
+
main_context : int
+
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
+
notes : string
+
(Optional) Some notes for the model
+
partition_name : string
+
(Optional) partitioned graph name.
+
source : string
+
(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain
+
+ +#### Inputs (1 - ∞) + +
+
inputs (variadic) : T
+
List of tensors for inputs
+
+ +#### Outputs (1 - ∞) + +
+
outputs (variadic) : T
+
One or more outputs, list of tensors for outputs
+
+ +#### Type Constraints + +
+
T : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types.
+
+ + ### **com.microsoft.EmbedLayerNormalization** EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8393978120..4be49bdaea 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3597,6 +3597,9 @@ struct OrtApi { * "rpc_control_latency": QNN RPC control latency. * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the Onnx skeleton model. + * 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context. + * The path is relative path to the Onnx skeleton model file. * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 64f8c8d86f..21f9db7e48 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2844,6 +2844,83 @@ void RegisterContribSchemas() { propagateElemTypeFromInputToOutput(ctx, 0, 0); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(EPContext) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("Onnx node container for EP context.") + .Attr( + "main_context", + "Usually each single EPContext associate with a graph partition." + "But for some case like QNN, it has single EPContext contains all partitions." + "In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context." + "The path is relative to this Onnx file. Default is 1.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "ep_cache_context", + "payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "embed_mode", + "1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content." + "The path is relative to this Onnx file. Default is 1.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "ep_sdk_version", + "(Optional) SDK version used to convert the model.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "partition_name", + "(Optional) partitioned graph name.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "source", + "(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE) + .AllowUncheckedAttributes() + .Input( + 0, + "inputs", + "List of tensors for inputs", + "T", + OpSchema::Variadic, + true, + 1, + OpSchema::NonDifferentiable) + .Output( + 0, + "outputs", + "One or more outputs, list of tensors for outputs", + "T", + OpSchema::Variadic, + true, + 1, + OpSchema::NonDifferentiable) + .TypeConstraint( + "T", + {"tensor(int8)", + "tensor(int16)", + "tensor(int32)", + "tensor(int64)", + "tensor(uint8)", + "tensor(uint16)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(float16)", + "tensor(float)", + "tensor(double)"}, + "Constrain input and output types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + }); + static const char* BitmaskDropout_ver1_doc = R"DOC( BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs: output (floating-point tensor) and mask (optional `Tensor`). If `training_mode` is true then the output Y will be a random dropout. diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc new file mode 100644 index 0000000000..afa5d1ce77 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/graph/constants.h" +#include "core/providers/qnn/builder/qnn_model.h" + +#include +#include +#include + +namespace onnxruntime { +namespace qnn { + +Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, + bool& is_qnn_ctx_model) { + is_qnn_ctx_model = false; + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); + // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type + int count = 0; + for (const auto& node : graph_viewer.Nodes()) { + if (EPCONTEXT_OP == node.OpType()) { + is_qnn_ctx_model = true; + } + ++count; + } + ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node."); + } + return Status::OK(); +} + +bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) { + // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type + for (const auto& node : graph_viewer.Nodes()) { + if (EPCONTEXT_OP == node.OpType()) { + return true; + } + } + return false; +} + +Status CreateNodeArgs(const std::vector& names, + const std::unordered_map& tensor_info_table, + std::vector& node_args, + onnxruntime::Graph& graph) { + using namespace ONNX_NAMESPACE; + for (size_t i = 0; i < names.size(); ++i) { + std::string name = names[i]; + ORT_RETURN_IF(tensor_info_table.find(name) == tensor_info_table.end(), "Tensor name: ", name, " not found in tensor_info_table"); + const OnnxTensorInfo& tensor_info = tensor_info_table.at(name); + TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(tensor_info.data_type_); + for (size_t j = 0; j < tensor_info.shape_.size(); ++j) { + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]); + } + auto& input_arg = graph.GetOrCreateNodeArg(name, &tensor_type); + node_args.push_back(&input_arg); + } + return Status::OK(); +} + +Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, + std::string& ep_cache_context, + const logging::Logger& logger) { + using namespace onnxruntime; + std::shared_ptr model; + ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); + const auto& graph = model->MainGraph(); + ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context)); + + return Status::OK(); +} + +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + std::string& ep_cache_context) { + const auto& node = graph_viewer.Nodes().begin(); + NodeAttrHelper node_helper(*node); + bool is_embed_mode = node_helper.Get(EMBED_MODE, true); + if (is_embed_mode) { + ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, ""); + } else { + std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); + + std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() + + "/" + external_qnn_context_binary_file_name); + size_t buffer_size{0}; + std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary); + ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); + cache_file.seekg(0, cache_file.end); + buffer_size = static_cast(cache_file.tellg()); + ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); + cache_file.seekg(0, cache_file.beg); + ep_cache_context.reserve(buffer_size); + // Load file into buffer + ep_cache_context.assign(std::istreambuf_iterator(cache_file), std::istreambuf_iterator()); + cache_file.close(); + ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file."); + } + ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty."); + return Status::OK(); +} + +Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger) { + if (!is_metadata_ready_) { + using namespace onnxruntime; + std::shared_ptr model; + ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); + const auto& graph = GraphViewer(model->MainGraph()); + const auto& node = graph.Nodes().begin(); + NodeAttrHelper node_helper(*node); + model_name_ = graph.Name(); + model_description_ = graph.Description(); + graph_partition_name_ = node_helper.Get(PARTITION_NAME, ""); + cache_source_ = node_helper.Get(SOURCE, ""); + is_metadata_ready_ = true; + } + model_name = model_name_; + model_description = model_description_; + graph_partition_name = graph_partition_name_; + cache_source = cache_source_; + + return Status::OK(); +} + +bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path, + const std::string& model_description, + const onnxruntime::PathString& model_pathstring) { + // Avoid duplicate work + if (ctx_file_exists_) { + return ctx_file_exists_; + } + model_description_ = model_description; + // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default + if (customer_context_cache_path.empty()) { + context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx"; + } else { + context_cache_path_ = customer_context_cache_path; + } + + ctx_file_exists_ = std::filesystem::exists(context_cache_path_); + + return ctx_file_exists_; +} + +Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name, + const std::string& graph_partition_name, + const logging::Logger& logger) { + ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); + + std::string model_name_from_ctx_cache; + std::string model_description_from_ctx_cache; + std::string graph_partition_name_from_ctx_cache; + std::string cache_source; + ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_, + model_name_from_ctx_cache, + model_description_from_ctx_cache, + graph_partition_name_from_ctx_cache, + cache_source, + logger)); + + // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT + if (cache_source != kQnnExecutionProvider) { + return Status::OK(); + } + + ORT_RETURN_IF(model_name != model_name_from_ctx_cache, + "Model file name from context cache metadata: " + model_name_from_ctx_cache + + " is different with target: " + model_name + + ". Please make sure the context binary file matches the model."); + + ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache, + "Model description from context cache metadata: " + model_description_from_ctx_cache + + " is different with target: " + model_description_ + + ". Please make sure the context binary file matches the model."); + + ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_, + "Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache + + " is different with target: " + graph_partition_name + + ". You may need to re-generate the context binary file."); + + get_capability_round_2_ = true; + return Status::OK(); +} + +Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const logging::Logger& logger) { + std::unordered_map domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}}; + Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, logger); + auto& graph = model.MainGraph(); + graph.SetDescription(model_description_); + + using namespace ONNX_NAMESPACE; + int index = 0; + // Still need more work to support multiple partition, it's out of EP's scope. + // Already have code to make sure it's single partition before this method get invoked. + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); + Node& fused_node = fused_node_graph.fused_node; + // graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id] + // dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability + auto qnn_model_kv = qnn_models.find(fused_node.Name()); + ORT_RETURN_IF(qnn_model_kv == qnn_models.end(), fused_node.Name(), " not exist in QnnModel table."); + + auto qnn_model = qnn_model_kv->second.get(); + std::vector inputs; + std::vector outputs; + ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetInputNames(), qnn_model->GetInputsInfo(), inputs, graph)); + ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetOutputNames(), qnn_model->GetOutputsInfo(), outputs, graph)); + + const std::string& graph_name = graph_viewer.Name(); + auto& ep_node = graph.AddNode(graph_name, + EPCONTEXT_OP, + "Onnx Qnn context binary cache for graph partition: " + graph_name, + inputs, + outputs, + nullptr, + kMSDomain); + + // Only dump the context buffer once since all QNN graph are in one single context + if (0 == index) { + if (qnn_context_embed_mode_) { + std::string cache_payload(buffer, buffer + buffer_size); + ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); + } else { + std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin"); + std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string()); + std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary); + if (!of_stream) { + LOGS(logger, ERROR) << "Failed to open create context file."; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file."); + } + of_stream.write(reinterpret_cast(buffer), buffer_size); + ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name); + } + } else { + ep_node.AddAttribute(MAIN_CONTEXT, static_cast(0)); + } + int64_t embed_mode = qnn_context_embed_mode_ ? static_cast(1) : static_cast(0); + ep_node.AddAttribute(EMBED_MODE, embed_mode); + ep_node.AddAttribute(EP_SDK_VER, sdk_build_version); + ep_node.AddAttribute(PARTITION_NAME, graph_name); + ep_node.AddAttribute(SOURCE, kQnnExecutionProvider); + ++index; + } + ORT_RETURN_IF_ERROR(graph.Resolve()); + ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_)); + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h new file mode 100644 index 0000000000..1ff8d00a0e --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "qnn_def.h" +#include "core/common/logging/logging.h" +#include "core/graph/graph_viewer.h" +#include "core/providers/shared/utils/utils.h" +#include "core/graph/model.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +namespace qnn { + +class QnnModel; + +static const std::string EPCONTEXT_OP = "EPContext"; +static const std::string MAIN_CONTEXT = "main_context"; +static const std::string EMBED_MODE = "embed_mode"; +static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; +static const std::string EP_SDK_VER = "ep_sdk_version"; +static const std::string PARTITION_NAME = "partition_name"; +static const std::string SOURCE = "source"; + +Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, + bool& is_qnn_ctx_model); + +bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer); + +Status CreateNodeArgs(const std::vector& names, + const std::unordered_map& tensor_info_table, + std::vector& node_args, + onnxruntime::Graph& graph); + +Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, + std::string& ep_engine_cache, + const logging::Logger& logger); + +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + std::string& ep_cache_context); + +class QnnCacheModelHandler { + public: + QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) { + } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler); + + Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, + std::string& ep_engine_cache, + const logging::Logger& logger) const { + if (is_qnn_ctx_model) { + ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache)); + } else if (is_ctx_cache_file_exist) { + ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger)); + } + + return Status::OK(); + } + + bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const std::string& model_description, + const onnxruntime::PathString& model_pathstring); + + bool GetIsContextCacheFileExists() const { + return ctx_file_exists_; + } + + Status ValidateWithContextFile(const std::string& model_name, + const std::string& graph_name, + const logging::Logger& logger); + + Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger); + + Status GenerateCtxCacheOnnxModel(unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const logging::Logger& logger); + + private: + bool is_metadata_ready_ = false; + std::string model_name_ = ""; + std::string model_description_ = ""; + std::string graph_partition_name_ = ""; + std::string cache_source_ = ""; + std::string context_cache_path_ = ""; + bool ctx_file_exists_ = false; + bool get_capability_round_2_ = false; + bool qnn_context_embed_mode_ = true; +}; // QnnCacheModelHandler + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index e2083371ac..043b8a7900 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -3,8 +3,6 @@ #include "qnn_backend_manager.h" #include "qnn_model.h" -#include -#include #include #include "QnnOpDef.h" #include "HTP/QnnHtpPerfInfrastructure.h" @@ -16,6 +14,7 @@ #include "core/common/gsl.h" #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" // Flag to determine if Backend should do node validation for each opNode added #define DO_GRAPH_NODE_VALIDATIONS 1 @@ -28,8 +27,6 @@ typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(const QnnInterface_t** typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(const QnnSystemInterface_t*** providerList, uint32_t* numProviders); -constexpr const char* QNN_PROVIDER = "ORTQNNEP"; - static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnInterface_t* qnn_interface) { return qnn_interface->apiVersion.coreApiVersion; } @@ -412,135 +409,64 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } -bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring) { - // Avoid duplicate work - if (!context_cache_path_.empty()) { - return ctx_file_exists_; - } - model_description_ = model_description; - // Use user provided context cache file path if exist, otherwise try model_file.onnx.bin by default - if (customer_context_cache_path.empty()) { - context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin"; - } else { - context_cache_path_ = customer_context_cache_path; - } - - ctx_file_exists_ = std::filesystem::exists(context_cache_path_); - - return ctx_file_exists_; -} - -Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) { - const std::vector data{value}; - std::vector data_bytes(sizeof(uint16_t) / sizeof(unsigned char)); - ORT_RETURN_IF_ERROR(onnxruntime::utils::WriteLittleEndian(gsl::make_span(data), gsl::make_span(data_bytes))); - of_stream.write(reinterpret_cast(data_bytes.data()), data_bytes.size()); - return Status::OK(); -} - -Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) { +std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint64_t& written_buffer_size) { if (nullptr == qnn_interface_.contextGetBinarySize || nullptr == qnn_interface_.contextGetBinary) { LOGS(*logger_, ERROR) << "Failed to get valid function pointer."; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get valid function pointer."); + return nullptr; } uint64_t required_buffer_size(0); Qnn_ErrorHandle_t rt = qnn_interface_.contextGetBinarySize(context_, &required_buffer_size); if (QNN_CONTEXT_NO_ERROR != rt) { LOGS(*logger_, ERROR) << "Failed to get QNN context binary size. Error code: " << rt; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get QNN context binary size."); + return nullptr; } std::unique_ptr context_buffer = std::make_unique(required_buffer_size); if (nullptr == context_buffer) { LOGS(*logger_, ERROR) << "Failed to allocate buffer for context cache."; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate buffer for context cache."); + return nullptr; } - uint64_t written_buffer_size(0); rt = qnn_interface_.contextGetBinary(context_, reinterpret_cast(context_buffer.get()), required_buffer_size, &written_buffer_size); if (QNN_CONTEXT_NO_ERROR != rt) { LOGS(*logger_, ERROR) << "Failed to get context binary."; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get context binary."); + return nullptr; } if (required_buffer_size < written_buffer_size) { LOGS(*logger_, ERROR) << "Context written buffer size: " << written_buffer_size << " exceeds allocated buffer size: " << required_buffer_size; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context written buffer exceeds allocated buffer size."); + return nullptr; } - std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary); - if (!of_stream) { - LOGS(*logger_, ERROR) << "Failed to open cached context file."; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file."); + LOGS(*logger_, VERBOSE) << "Get context binary buffer succeed."; + return context_buffer; +} + +Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache, + QnnModel& qnn_model, + bool& loaded_from_cache) { + loaded_from_cache = false; + + if (!ep_engine_cache.empty()) { + ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model)); + loaded_from_cache = true; } - // Write Ort metadata into context binary file - uint16_t model_name_length = static_cast(model_name.length()); - uint16_t graph_name_length = static_cast(graph_name.length()); - uint16_t model_description_length = static_cast(model_description_.length()); - - // Header: uint16_t(totale_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description - uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length; - uint16_t totale_length = header_length + static_cast(strlen(QNN_PROVIDER)); - of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER)); - - ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, header_length)); - - ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_name_length)); - of_stream.write(model_name.c_str(), model_name_length); - - ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length)); - of_stream.write(graph_name.c_str(), graph_name_length); - - ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length)); - of_stream.write(model_description_.c_str(), model_description_length); - model_description_.clear(); - - LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << totale_length; - - of_stream.write(reinterpret_cast(context_buffer.get()), written_buffer_size); - - LOGS(*logger_, VERBOSE) << "Dump QNN Context completed."; return Status::OK(); } -Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) { +Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; ORT_RETURN_IF(result, "Failed to get valid function pointer."); - ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); - - uint64_t buffer_size{0}; - std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary); - ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); - cache_file.seekg(0, cache_file.end); - buffer_size = cache_file.tellg(); - ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); - cache_file.seekg(0, cache_file.beg); - // Skip Ort generated metadata - if (ort_generated_ctx_cache_) { - cache_file.seekg(ort_ctx_metadata_length_); - buffer_size -= ort_ctx_metadata_length_; - } - - std::unique_ptr buffer = std::make_unique(buffer_size); - ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); - - // Load file into buffer - const auto& read_result = cache_file.read(reinterpret_cast(buffer.get()), buffer_size); - cache_file.close(); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); - QnnSystemContext_Handle_t sys_ctx_handle = nullptr; auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); @@ -548,8 +474,8 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) { const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; Qnn_ContextBinarySize_t binary_info_size{0}; rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - static_cast(buffer.get()), - buffer_size, + static_cast(const_cast(buffer.c_str())), + static_cast(buffer.length()), &binary_info, &binary_info_size); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); @@ -576,12 +502,14 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) { rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, (const QnnContext_Config_t**)&context_config_, - static_cast(buffer.get()), - buffer_size, + static_cast(const_cast(buffer.c_str())), + static_cast(buffer.length()), &context_, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); + // More work to support multiple partition, how to map the graph name in compile to qnn graph name + // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); qnn_sys_interface_.systemContextFree(sys_ctx_handle); @@ -590,126 +518,10 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) { ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo()); context_created_ = true; - model_description_.clear(); - model_description_from_ctx_cache_.clear(); LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed."; return Status::OK(); } -/* \brief: Read string data from binary file with given length - * \param[in] binary_file - file stream of the binary file - * \param[out] result_str - string read from binary file - * \param[out] length - length to read - */ -Status ReadStringFromBinaryFile(std::ifstream& binary_file, std::string& result_str, size_t length) { - result_str.resize(length); - const auto& read_result = binary_file.read(result_str.data(), length); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file."); - - return Status::OK(); -} - -/* \brief: Read a uint16_t from binary file - * \param[in] binary_file - file stream of the binary file - * \param[out] value - uint16_t value - */ -Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) { - std::unique_ptr buffer = std::make_unique(sizeof(uint16_t)); - ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for buffer."); - const auto& read_result = binary_file.read(buffer.get(), sizeof(uint16_t)); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file."); - - auto src = gsl::make_span(reinterpret_cast(buffer.get()), sizeof(uint16_t)); - std::vector dst(1); - ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(src, gsl::make_span(dst))); - value = dst[0]; - - return Status::OK(); -} - -/* \brief: Try to get metadata from Ort generated context cache binary file. - * Cached context binary file generated by Ort has some metadata which can be used for validation with the model - * to avoid user choose a wrong context binary file which is not for this model - * It is treated as Qnn generated context binary file if no metadata found from the file - */ -Status QnnBackendManager::GetMetadataFromOrtContextFile() { - // Only try parse meta data once - if (ctx_metadata_tried_) { - return Status::OK(); - } - ctx_metadata_tried_ = true; - - uint64_t buffer_size = 0; - std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary); - ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file."); - cache_file.seekg(0, cache_file.end); - buffer_size = cache_file.tellg(); - ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); - cache_file.seekg(0, cache_file.beg); - - // Read ort flag - std::string ort_flag(""); - size_t ort_flag_length = strlen(QNN_PROVIDER); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, ort_flag, ort_flag_length)); - - // It's not Ort generated context binary file - if (strncmp(ort_flag.c_str(), QNN_PROVIDER, ort_flag_length) != 0) { - return Status::OK(); - } - ort_generated_ctx_cache_ = true; - - uint16_t str_length = 0; - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); - ort_ctx_metadata_length_ = str_length + static_cast(ort_flag_length); - - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast(str_length))); - - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast(str_length))); - - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast(str_length))); - - return Status::OK(); -} - -/* \brief: Validate the model file name and graph name with Ort generated context cache metadata - * \param[in] model_name - model file name - * \param[in] graph_name - graph name, e.g Ort_QNN_[hash_id]_[id]. Since GetCapability is called twice, - * [hash_id]_[id] changes even for same graph, - * so only validate the graph name for 2nd call - */ -Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) { - ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); - - // Get metadata from cached context binary file - ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile()); - - // The context binary file doesn't have ORT metadata, so it is generated from QNN toolchain not from ORT - if (!ort_generated_ctx_cache_) { - return Status::OK(); - } - - ORT_RETURN_IF(model_name != model_name_from_ctx_cache_, - "Model file name from context cache metadata: " + model_name_from_ctx_cache_ + - " is different with target: " + model_name + - ". Please make sure the context binary file matches the model."); - - ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_, - "Model description from context cache metadata: " + model_description_from_ctx_cache_ + - " is different with target: " + model_description_ + - ". Please make sure the context binary file matches the model."); - - ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_, - "Graph name from context cache metadata: " + graph_name_from_ctx_cache_ + - " is different with target: " + graph_name + - ". You may need to re-generate the context binary file."); - - get_capability_round_2_ = true; - return Status::OK(); -} - Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; @@ -728,8 +540,9 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ ORT_RETURN_IF_ERROR(LoadQnnSystemLib()); } + sdk_build_version_ = GetBackendBuildId(); LOGS(logger, VERBOSE) << "Backend build version: " - << GetBackendBuildId(); + << sdk_build_version_; SetLogger(&logger); LOGS(logger, VERBOSE) << "SetLogger succeed."; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 402f842c7a..8f4a0002dd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -22,6 +22,7 @@ namespace onnxruntime { namespace qnn { class QnnModel; +class QnnCacheModelHandler; class QnnBackendManager { public: @@ -71,13 +72,11 @@ class QnnBackendManager { return CreateContext(); } - Status DumpQnnContext(const std::string& model_name, const std::string& graph_name); + std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); - Status LoadCachedQnnContext(QnnModel& qnn_model); - - Status GetMetadataFromOrtContextFile(); - - Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name); + Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache, + QnnModel& qnn_model, + bool& loaded_from_cache); Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); @@ -91,14 +90,6 @@ class QnnBackendManager { const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; } - std::string GetBackendBuildId() { - char* backend_build_id{nullptr}; - if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) { - LOGS(*logger_, ERROR) << "Unable to get build Id from the backend."; - } - return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id)); - } - void SetLogger(const logging::Logger* logger) { if (logger_ == nullptr) { logger_ = logger; @@ -133,9 +124,7 @@ class QnnBackendManager { void SetQnnBackendType(uint32_t backend_id); QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } - bool IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring); + const std::string& GetSdkVersion() { return sdk_build_version_; } private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -177,6 +166,16 @@ class QnnBackendManager { return ret; } + std::string GetBackendBuildId() { + char* backend_build_id{nullptr}; + if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) { + LOGS(*logger_, ERROR) << "Unable to get build Id from the backend."; + } + return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id)); + } + + Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model); + private: const std::string backend_path_; const logging::Logger* logger_ = nullptr; @@ -201,16 +200,7 @@ class QnnBackendManager { std::vector op_package_paths_; uint32_t rpc_control_latency_ = 0; HtpPerformanceMode htp_performance_mode_; - std::string model_name_from_ctx_cache_ = ""; - std::string graph_name_from_ctx_cache_ = ""; - std::string model_description_from_ctx_cache_ = ""; - std::string model_description_ = ""; - std::string context_cache_path_ = ""; - bool ctx_file_exists_ = false; - bool ctx_metadata_tried_ = false; - bool ort_generated_ctx_cache_ = false; - bool get_capability_round_2_ = false; - uint16_t ort_ctx_metadata_length_ = 0; + std::string sdk_build_version_ = ""; #ifdef _WIN32 std::set mod_handles_; #endif diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index db7196b4c2..0a458f2602 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -36,15 +36,16 @@ Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer, initializer_inputs_.emplace(graph_ini.first); } auto input_defs = fused_node.InputDefs(); - ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, inputs_info_, model_input_index_map_, true)); + ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true)); auto output_defs = fused_node.OutputDefs(); - ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, outputs_info_, model_output_index_map_)); + ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_)); return Status::OK(); } Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer>& input_output_defs, + std::vector& input_output_names, std::unordered_map& input_output_info_table, std::unordered_map& input_output_index_map, bool is_input) { @@ -72,6 +73,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainertensor_type().elem_type(); // use index i so that for graph input, it has initializers included input_output_info_table.emplace(std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(i, data_type, std::move(shape))); + input_output_names.push_back(name); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 934980f05f..373995106f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -47,6 +47,7 @@ class QnnModel { Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node); Status ParseGraphInputOrOutput(ConstPointerContainer>& input_output_defs, + std::vector& input_output_names, std::unordered_map& input_output_info_table, std::unordered_map& input_output_index, bool is_input = false); @@ -74,6 +75,24 @@ class QnnModel { Status DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info); + const std::vector& GetInputNames() const { + return input_names_; + } + + const std::vector& GetOutputNames() const { + return output_names_; + } + + const std::unordered_map& GetInputsInfo() const { + return inputs_info_; + } + + const std::unordered_map& GetOutputsInfo() const { + return outputs_info_; + } + + const std::string& Name() { return graph_info_->Name(); } + private: const NodeUnit& GetNodeUnit(const Node* node, const std::unordered_map& node_unit_map) const; @@ -87,13 +106,13 @@ class QnnModel { QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } - private: size_t GetInputOutputIndex(const std::string& name, const std::unordered_map& io_info) const { auto it = io_info.find(name); ORT_ENFORCE(it != io_info.end(), "Input/Output name not found."); return it->second.index_; } + private: const logging::Logger& logger_; std::unique_ptr graph_info_; QnnBackendManager* qnn_backend_manager_ = nullptr; @@ -102,6 +121,8 @@ class QnnModel { std::unordered_map model_output_index_map_; // TODO: remove initializer_inputs_, use QnnModelWrapper std::unordered_set initializer_inputs_; + std::vector input_names_; + std::vector output_names_; std::unordered_map inputs_info_; std::unordered_map outputs_info_; std::vector qnn_inputs_; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index ec5316eb13..6cd9cbac72 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -101,6 +101,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_; } + bool qnn_context_embed_mode = true; + static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode"; + auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE); + if (context_cache_embed_mode_pos != runtime_options_.end()) { + qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode; + } + static const std::string BACKEND_PATH = "backend_path"; auto backend_path_pos = runtime_options_.find(BACKEND_PATH); @@ -147,6 +155,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio rpc_control_latency_, htp_performance_mode_, std::move(qnn_saver_path)); + qnn_cache_model_handler_ = std::make_unique(qnn_context_embed_mode); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -262,10 +271,16 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto& logger = *GetLogger(); bool load_from_cached_context = false; - if (context_cache_enabled_) { - load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_, - graph_viewer.Description(), - graph_viewer.ModelPath().ToPathString()); + bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer); + if (is_qnn_ctx_model) { + load_from_cached_context = true; + } + + // This is for case: QDQ model + Onnx Qnn context cache model + if (context_cache_enabled_ && !is_qnn_ctx_model) { + load_from_cached_context = qnn_cache_model_handler_->IsContextCacheFileExists(context_cache_path_, + graph_viewer.Description(), + graph_viewer.ModelPath().ToPathString()); } // Load from cached context will load the QnnSystem lib and skip the Qnn context creation @@ -275,7 +290,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer return result; } - if (context_cache_enabled_ && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) { + if ((context_cache_enabled_ || is_qnn_ctx_model) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) { LOGS(logger, ERROR) << "Qnn context cache only works for HTP or DSP backend."; return result; } @@ -365,9 +380,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const size_t num_of_partitions = result.size(); - if (load_from_cached_context && 1 == num_of_partitions) { - rt = qnn_backend_manager_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()), - result[0]->sub_graph->GetMetaDef()->name); + if (!is_qnn_ctx_model && load_from_cached_context && 1 == num_of_partitions) { + rt = qnn_cache_model_handler_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()), + result[0]->sub_graph->GetMetaDef()->name, + logger); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN failed to validate context cache metadata: " << rt.ErrorMessage(); return result; @@ -447,22 +463,28 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); + Node& fused_node = fused_nodes_and_graphs[0].fused_node; + const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); - if (context_cache_enabled_) { + bool is_qnn_ctx_model = false; + ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); + + if (context_cache_enabled_ || is_qnn_ctx_model) { ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); - Node& fused_node = fused_nodes_and_graphs[0].fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); - // The dumy_model_description won't be used since IsContextCacheFileExists call cached the result - // The graph_viewer.Description here is not same with original model - std::string dumy_model_description = ""; - bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_, - dumy_model_description, - graph_viewer.ModelPath().ToPathString()); + std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); + bool loaded_from_cache = false; + std::string ep_engine_cache; + ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer, + context_cache_path_, + is_qnn_ctx_model, + qnn_cache_model_handler_->GetIsContextCacheFileExists(), + ep_engine_cache, + logger)); + ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache, + *(qnn_model.get()), + loaded_from_cache)); // Load and execute from cached context if exist - if (load_from_cached_context) { - std::unique_ptr qnn_model = std::make_unique(logger, - qnn_backend_manager_.get()); - ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get()))); + if (loaded_from_cache) { ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); @@ -473,18 +495,22 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); return Status::OK(); - } else { - // Load and execute from Onnx model if not exit and dump the context - ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); - // graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id] - // dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability - ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()), - graph_viewer.Name())); } - return Status::OK(); } ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); + if (context_cache_enabled_ && !is_qnn_ctx_model) { + ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); + uint64_t buffer_size(0); + auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); + ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(), + buffer_size, + qnn_backend_manager_->GetSdkVersion(), + fused_nodes_and_graphs, + qnn_models_, + logger)); + } + qnn_cache_model_handler_.reset(); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 3827e2044e..c63a60018a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -8,6 +8,7 @@ #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" namespace onnxruntime { @@ -66,6 +67,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool context_cache_enabled_ = false; std::string context_cache_path_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. + std::unique_ptr qnn_cache_model_handler_; }; } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 287d657a2c..da67987f52 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -56,6 +56,8 @@ void usage() { "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + "\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n" + "\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n" "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" @@ -454,6 +456,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (value.empty()) { ORT_THROW("Please provide the QNN backend path."); } + } else if (key == "qnn_context_embed_mode") { + if (value != "0") { + ORT_THROW("Set to 0 to disable qnn_context_embed_mode."); + } } else if (key == "qnn_context_cache_enable") { if (value != "1") { ORT_THROW("Set to 1 to enable qnn_context_cache_enable."); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index b4c84d893c..396fc193bf 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -261,7 +261,8 @@ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, ProviderOptions qnn_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-4f, - logging::Severity log_severity = logging::Severity::kERROR) { + logging::Severity log_severity = logging::Severity::kERROR, + const std::string& qnn_ctx_model_path = "") { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; @@ -321,8 +322,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. TryEnableQNNSaver(qnn_options); std::vector qnn_qdq_outputs; - InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); + if (!qnn_ctx_model_path.empty()) { + onnx::ModelProto model_proto; + onnxruntime::Model qnn_ctx_model; + // Load the QNN context cache model from path specified + ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(qnn_ctx_model_path), model_proto)); + std::string qnn_ctx_model_data; + model_proto.SerializeToString(&qnn_ctx_model_data); + // Run QNN context cache model on QNN EP and collect outputs. + InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", QnnExecutionProviderWithOptions(qnn_options), + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); + } else { + // Run QDQ model on QNN EP and collect outputs. + InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); + } if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs. diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index f77c098f72..be8afa7636 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -679,10 +679,11 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) { true); // Use com.microsoft domain for Q/DQ ops } -// Run QDQ model on HTP twice -// 1st run will generate the Qnn context cache binary file -// 2nd run will load and run from Qnn context cache binary file -TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { +// Run QDQ model on HTP 3 times +// 1st run will generate the Qnn context cache onnx file +// 2nd run will load and run from QDQ model + Qnn context cache model +// 3rd run directly loads and run from Qnn context cache model +TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -690,7 +691,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["qnn_context_cache_enable"] = "1"; - const std::string context_binary_file = "./qnn_context_binary_test.bin"; + const std::string context_binary_file = "./qnn_context_binary_test.onnx"; provider_options["qnn_context_cache_path"] = context_binary_file; const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); @@ -707,12 +708,120 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // 2nd run will load and run from Qnn context cache binary file + // 2nd run loads and run from QDQ model + Qnn context cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); + + // 3rd run directly loads and run from Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + 1e-4f, + logging::Severity::kERROR, + context_binary_file); +} + +// Run QDQ model on HTP 3 times +// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file +// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file +TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["qnn_context_cache_enable"] = "1"; + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + provider_options["qnn_context_cache_path"] = context_binary_file; + provider_options["qnn_context_embed_mode"] = "0"; + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. + // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNN_8283143575221199085_1.bin")); + + // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All); + + // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + 1e-4f, + logging::Severity::kERROR, + context_binary_file); +} + +// Run QDQ model on HTP with 2 inputs +// 1st run will generate the Qnn context cache onnx file +// 2nd run will load and run from QDQ model + Qnn context cache model +// 3rd run directly loads and run from Qnn context cache model +TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["qnn_context_cache_enable"] = "1"; + const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; + provider_options["qnn_context_cache_path"] = context_binary_file; + + const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); + const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Add"; + + // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. + // 1st run will generate the Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // 2nd run loads and run from QDQ model + Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All); + + // 3rd run directly loads and run from Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + 1e-4f, + logging::Severity::kERROR, + context_binary_file); } TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index f678b18ba9..491c896de8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -110,7 +110,7 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \ + -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \ /data/qdq_models/mobilenetv2-1.0_add_transpose_quant - task: CmdLine@2 @@ -118,5 +118,5 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \ + -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \ /data/qdq_models/mobilenetv2-1.0_add_transpose_quant