diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index ebee7aa4b5..bc150c3774 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -316,6 +316,26 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } +bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path, + const std::string& model_description, + const onnxruntime::PathString& model_pathstring) { + // Path already resolved by an earlier call: return the cached result (arguments are ignored) + if (!context_cache_path_.empty()) { + return ctx_file_exists_; + } + model_description_ = model_description; + // Use the user-provided context cache file path if one is given; otherwise default to the model path with ".bin" appended (e.g. model_file.onnx.bin) + if (customer_context_cache_path.empty()) { + context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin"; + } else { + context_cache_path_ = customer_context_cache_path; + } + + ctx_file_exists_ = std::filesystem::exists(context_cache_path_); + + return ctx_file_exists_; +} + Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) { const std::vector data{value}; std::vector data_bytes(sizeof(uint16_t) / sizeof(unsigned char)); @@ -324,9 +344,7 @@ Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) { return Status::OK(); } -Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring, - const std::string& model_name, - const std::string& graph_name) { +Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) { if (nullptr == qnn_interface_.contextGetBinarySize || nullptr == qnn_interface_.contextGetBinary) { LOGS(*logger_, ERROR) << "Failed to get valid function pointer."; @@ -362,7 +380,7 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context
written buffer exceeds allocated buffer size."); } - std::ofstream of_stream(context_cache_pathstring.c_str(), std::ofstream::binary); + std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary); if (!of_stream) { LOGS(*logger_, ERROR) << "Failed to open cached context file."; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file."); @@ -371,7 +389,10 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_ // Write Ort metadata into context binary file uint16_t model_name_length = static_cast(model_name.length()); uint16_t graph_name_length = static_cast(graph_name.length()); - uint16_t header_length = 3 * sizeof(uint16_t) + model_name_length + graph_name_length; + uint16_t model_description_length = static_cast(model_description_.length()); + + // Header: uint16_t(totale_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description + uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length; uint16_t totale_length = header_length + static_cast(strlen(QNN_PROVIDER)); of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER)); @@ -382,6 +403,11 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_ ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length)); of_stream.write(graph_name.c_str(), graph_name_length); + + ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length)); + of_stream.write(model_description_.c_str(), model_description_length); + model_description_.clear(); + LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << totale_length; of_stream.write(reinterpret_cast(context_buffer.get()), written_buffer_size); @@ -390,14 +416,16 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_ return Status::OK(); } -Status QnnBackendManager::LoadCachedQnnContext(const 
onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model) { +Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; ORT_RETURN_IF(result, "Failed to get valid function pointer."); + ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); + uint64_t buffer_size{0}; - std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary); + std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); cache_file.seekg(0, cache_file.end); buffer_size = cache_file.tellg(); @@ -466,6 +494,8 @@ Status QnnBackendManager::LoadCachedQnnContext(const onnxruntime::PathString& co ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo()); context_created_ = true; + model_description_.clear(); + model_description_from_ctx_cache_.clear(); LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed."; return Status::OK(); } @@ -502,12 +532,11 @@ Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) { } /* \brief: Try to get metadata from Ort generated context cache binary file. 
- * \param[in] context_cache_pathstring - context cache binary file path string * Cached context binary file generated by Ort has some metadata which can be used for validation with the model * to avoid user choose a wrong context binary file which is not for this model * It is treated as Qnn generated context binary file if no metadata found from the file */ -Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathString& context_cache_pathstring) { +Status QnnBackendManager::GetMetadataFromOrtContextFile() { // Only try parse meta data once if (ctx_metadata_tried_) { return Status::OK(); @@ -515,7 +544,7 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS ctx_metadata_tried_ = true; uint64_t buffer_size = 0; - std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary); + std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file."); cache_file.seekg(0, cache_file.end); buffer_size = cache_file.tellg(); @@ -533,17 +562,18 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS } ort_generated_ctx_cache_ = true; - uint16_t header_length = 0; - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, header_length)); - ort_ctx_metadata_length_ = header_length + static_cast(ort_flag_length); + uint16_t str_length = 0; + ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); + ort_ctx_metadata_length_ = str_length + static_cast(ort_flag_length); - uint16_t model_name_length = 0; - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, model_name_length)); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast(model_name_length))); + ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); + ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, 
static_cast(str_length))); - uint16_t graph_name_length = 0; - ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, graph_name_length)); - ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast(graph_name_length))); + ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); + ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast(str_length))); + + ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length)); + ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast(str_length))); return Status::OK(); } @@ -555,15 +585,30 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS * so only validate the graph name for 2nd call */ Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) { + ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); + + // Get metadata from cached context binary file + ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile()); + + // No ORT metadata in the context binary file: it was generated by the QNN toolchain rather than by ORT, so there is nothing to validate if (!ort_generated_ctx_cache_) { return Status::OK(); } ORT_RETURN_IF(model_name != model_name_from_ctx_cache_, - "Model file name from context cache metadata: " + model_name_from_ctx_cache_ + " is different with target: " + model_name); + "Model file name from context cache metadata: " + model_name_from_ctx_cache_ + + " is different with target: " + model_name + + ". Please make sure the context binary file matches the model."); + + ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_, + "Model description from context cache metadata: " + model_description_from_ctx_cache_ + + " is different with target: " + model_description_ + + ".
Please make sure the context binary file matches the model."); ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_, - "Graph name from context cache metadata: " + graph_name_from_ctx_cache_ + " is different with target: " + graph_name); + "Graph name from context cache metadata: " + graph_name_from_ctx_cache_ + + " is different with target: " + graph_name + + ". You may need to re-generate the context binary file."); get_capability_round_2_ = true; return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 7f15d3ed49..81f9816829 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -69,13 +69,11 @@ class QnnBackendManager { return CreateContext(); } - Status DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring, - const std::string& model_name, - const std::string& graph_name); + Status DumpQnnContext(const std::string& model_name, const std::string& graph_name); - Status LoadCachedQnnContext(const onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model); + Status LoadCachedQnnContext(QnnModel& qnn_model); - Status GetMetadataFromOrtContextFile(const onnxruntime::PathString& model_path); + Status GetMetadataFromOrtContextFile(); Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name); @@ -133,6 +131,10 @@ class QnnBackendManager { // NPU backend requires quantized model bool IsNpuBackend() { return is_npu_backend_; } + bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const std::string& model_description, + const onnxruntime::PathString& model_pathstring); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -197,6 +199,10 @@ class QnnBackendManager { HtpPerformanceMode htp_performance_mode_; std::string 
model_name_from_ctx_cache_ = ""; std::string graph_name_from_ctx_cache_ = ""; + std::string model_description_from_ctx_cache_ = ""; + std::string model_description_ = ""; + std::string context_cache_path_ = ""; + bool ctx_file_exists_ = false; bool ctx_metadata_tried_ = false; bool ort_generated_ctx_cache_ = false; bool get_capability_round_2_ = false; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 999ac0ff41..f43b7029ca 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -276,18 +276,9 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto& logger = *GetLogger(); bool load_from_cached_context = false; if (context_cache_enabled_) { - onnxruntime::PathString context_cache_pathstring; - load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(), - context_cache_pathstring); - - // Get metadata from cached context binary file - if (load_from_cached_context) { - auto rt = qnn_backend_manager_->GetMetadataFromOrtContextFile(context_cache_pathstring); - if (Status::OK() != rt) { - LOGS(logger, ERROR) << "Failed to get metadata from cached context binary file.
" << rt.ErrorMessage(); - return result; - } - } + load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_, + graph_viewer.Description(), + graph_viewer.ModelPath().ToPathString()); } // Load from cached context will load the QnnSystem lib and skip the Qnn context creation @@ -444,19 +435,6 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); @@ -466,15 +444,18 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support singel partition for context cache feature."); Node& fused_node = fused_nodes_and_graphs[0].fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); - onnxruntime::PathString context_cache_pathstring; - bool load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(), - context_cache_pathstring); + // dumy_model_description is only a placeholder: IsContextCacheFileExists returns the result cached by its earlier call and ignores the arguments + // graph_viewer.Description() here is not the same as the original model's description + std::string dumy_model_description = ""; + bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_, + dumy_model_description, + graph_viewer.ModelPath().ToPathString()); // Load and execute from cached context if exist if (load_from_cached_context) { std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get(), is_npu_backend); - ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(context_cache_pathstring, *(qnn_model.get()))); + ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get()))); ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); @@ -490,8 +471,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); // graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id] // dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability - ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(context_cache_pathstring, - GetFileNameFromModelPath(graph_viewer.ModelPath()), + ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()), graph_viewer.Name())); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 2f6166d46e..2fe507b70a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -54,9 +54,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs, const logging::Logger& logger); - bool IsContextCacheFileExists(const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_pathstring) const; - void ParseHtpPerformanceMode(std::string htp_performance_mode_string); private: diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 8ef29efdf0..74b9e24630 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -156,7 +156,6 @@ TEST_F(QnnHTPBackendTests, TestQDQAtanTest) { // 1st run will generate the Qnn context cache binary file // 2nd run will load and run from Qnn context cache binary file TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { - RunQDQSingleInputOpTest({1, 2, 3}, "Atan", "TestQDQGeluTest", 11, ExpectedEPNodeAssignment::All, 1); ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll";