From 35ecce45496ee752bc3f85618eba713bb50e6069 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 18 Oct 2023 15:30:35 -0700 Subject: [PATCH] [QNN EP] Reduce overhead of QNN context binary loading (#17965) ### Description Reduce overhead of QNN context binary loading by avoiding a memory copy ### Motivation and Context Reduce the session initialization time and memory usage when loading from a QNN context binary --- .../qnn/builder/onnx_ctx_model_helper.cc | 40 +++++++++++------- .../qnn/builder/onnx_ctx_model_helper.h | 40 ++++++++++-------- .../qnn/builder/qnn_backend_manager.cc | 23 +++-------- .../qnn/builder/qnn_backend_manager.h | 6 +-- .../providers/qnn/qnn_execution_provider.cc | 41 ++++++++----------- .../core/providers/shared/utils/utils.cc | 2 +- .../core/providers/shared/utils/utils.h | 2 +- 7 files changed, 74 insertions(+), 80 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 32b6a38793..7ccd765e65 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -60,26 +60,32 @@ Status CreateNodeArgs(const std::vector& names, return Status::OK(); } -Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, - std::string& ep_cache_context, - const logging::Logger& logger) { +Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { using namespace onnxruntime; std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); const auto& graph = model->MainGraph(); - ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context)); - - return Status::OK(); + return GetEpContextFromGraph(GraphViewer(graph), + ctx_onnx_model_path, 
qnn_backend_manager, + qnn_model); } -Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - std::string& ep_cache_context) { +Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model) { const auto& node = graph_viewer.Nodes().begin(); NodeAttrHelper node_helper(*node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); if (is_embed_mode) { - ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, ""); + const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), + static_cast(context_binary.length()), + qnn_model); } else { std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); @@ -88,17 +94,23 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); + cache_file.seekg(0, cache_file.end); buffer_size = static_cast(cache_file.tellg()); ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); + cache_file.seekg(0, cache_file.beg); - ep_cache_context.reserve(buffer_size); + std::unique_ptr buffer = std::make_unique(buffer_size); + ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); // Load file into buffer - ep_cache_context.assign(std::istreambuf_iterator(cache_file), std::istreambuf_iterator()); + const auto& read_result = cache_file.read(buffer.get(), buffer_size); + ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); cache_file.close(); - ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file."); 
+ return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), + static_cast(buffer_size), + qnn_model); } - ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty."); + return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 1ff8d00a0e..e9ca87a679 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -18,6 +18,7 @@ namespace onnxruntime { namespace qnn { class QnnModel; +class QnnBackendManager; static const std::string EPCONTEXT_OP = "EPContext"; static const std::string MAIN_CONTEXT = "main_context"; @@ -37,32 +38,24 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, - std::string& ep_engine_cache, - const logging::Logger& logger); - -Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - std::string& ep_cache_context); - class QnnCacheModelHandler { public: QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler); - Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - bool is_qnn_ctx_model, - bool is_ctx_cache_file_exist, - std::string& ep_engine_cache, - const logging::Logger& logger) const { + Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { if (is_qnn_ctx_model) { - ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache)); + return 
GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); } else if (is_ctx_cache_file_exist) { - ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger)); + return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); } - return Status::OK(); } @@ -92,12 +85,25 @@ class QnnCacheModelHandler { const std::unordered_map>& qnn_models, const logging::Logger& logger); + private: + Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger); + + Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model); + private: bool is_metadata_ready_ = false; + // model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model std::string model_name_ = ""; std::string model_description_ = ""; std::string graph_partition_name_ = ""; std::string cache_source_ = ""; + std::string context_cache_path_ = ""; bool ctx_file_exists_ = false; bool get_capability_round_2_ = false; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 043b8a7900..f8ee0f225f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -448,20 +448,7 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 return context_buffer; } -Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache, - QnnModel& qnn_model, - bool& loaded_from_cache) { - loaded_from_cache = false; - - if (!ep_engine_cache.empty()) { - ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model)); - loaded_from_cache = true; - } - - return Status::OK(); -} - -Status 
QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) { +Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; @@ -474,8 +461,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; Qnn_ContextBinarySize_t binary_info_size{0}; rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - static_cast(const_cast(buffer.c_str())), - static_cast(buffer.length()), + static_cast(buffer), + buffer_length, &binary_info, &binary_info_size); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); @@ -502,8 +489,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, (const QnnContext_Config_t**)&context_config_, - static_cast(const_cast(buffer.c_str())), - static_cast(buffer.length()), + static_cast(buffer), + buffer_length, &context_, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 8f4a0002dd..9cb6a32214 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -74,9 +74,7 @@ class QnnBackendManager { std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); - Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache, - QnnModel& qnn_model, - bool& loaded_from_cache); + Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model); Status 
SetupBackend(const logging::Logger& logger, bool load_from_cached_context); @@ -174,8 +172,6 @@ class QnnBackendManager { return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id)); } - Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model); - private: const std::string backend_path_; const logging::Logger* logger_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b456be9241..943dc8128a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -483,33 +483,28 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_qnn_ctx_model = false; ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); - if (context_cache_enabled_ || is_qnn_ctx_model) { + bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists(); + if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - bool loaded_from_cache = false; - std::string ep_engine_cache; - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer, - context_cache_path_, - is_qnn_ctx_model, - qnn_cache_model_handler_->GetIsContextCacheFileExists(), - ep_engine_cache, - logger)); - ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache, - *(qnn_model.get()), - loaded_from_cache)); // Load and execute from cached context if exist - if (loaded_from_cache) { - ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); - ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer, 
+ context_cache_path_, + is_qnn_ctx_model, + is_ctx_file_exist, + qnn_backend_manager_.get(), + *(qnn_model.get()), + logger)); + ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); + ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); - // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] - // the name here should be same with context->node_name in compute_info - LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name(); - qnn_models_.emplace(fused_node.Name(), std::move(qnn_model)); + // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] + // the name here should be same with context->node_name in compute_info + LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name(); + qnn_models_.emplace(fused_node.Name(), std::move(qnn_model)); - ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); - return Status::OK(); - } + ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); + return Status::OK(); } ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); @@ -524,8 +519,6 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_models_, logger)); } - qnn_cache_model_handler_.reset(); - return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 7919f08166..6b1207d3d1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -119,7 +119,7 @@ int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const { return node_attributes_.at(key).i(); } -std::string NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { +const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { if (!HasAttr(key)) return def_val; diff --git a/onnxruntime/core/providers/shared/utils/utils.h 
b/onnxruntime/core/providers/shared/utils/utils.h index 744c8779c4..db07938c18 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -44,7 +44,7 @@ class NodeAttrHelper { int64_t Get(const std::string& key, int64_t def_val) const; - std::string Get(const std::string& key, const std::string& def_val) const; + const std::string& Get(const std::string& key, const std::string& def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const;