From fb61e14153b6a1263c15ea3b62d6bbbc5bde9848 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 26 Jul 2024 16:56:44 -0700 Subject: [PATCH] Add QNN EP option context_node_name_prefix to set EPContext node name prefix (#21236) ### Description Add QNN EP option context_node_name_prefix to set EPContext node name prefix ### Motivation and Context For the case to workaround QNN context PD memory limit, user need split the model into pieces and generate the QNN context model separately. It could happen that the generated EPContext node in separate graph has same node name. This will cause issue if glue those EPContext nodes together into a single model. To avoid this user can set this context_node_name_prefix for each split pieces to make the node name unique. --- .../onnxruntime_session_options_config_keys.h | 4 ++ .../providers/qnn/qnn_execution_provider.cc | 9 ++++- .../providers/qnn/qnn_execution_provider.h | 1 + .../test/providers/qnn/qnn_ep_context_test.cc | 39 +++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 17ae649e6f..209fd4279c 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -265,6 +265,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // "1": dump the EP context into the Onnx model. (default). static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; +// Specify the EPContext node name prefix to make it unique +// in case user need to merge/connect multiple EPContext nodes in one model +static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; + // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. // Option values: // - "0": Gemm FastMath mode is not enabled. [DEFAULT] diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 539b456cb6..c56a47e674 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -199,6 +199,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; + + // For the case that workaround QNN context PD memory limit, user need split the model into pieces and + // generate the QNN context model separately. + // It could happen that the generated EPContext node in separate graph has same node name. + // User can set this context_node_name_prefix for each split pieces to avoid that happens. + context_node_name_prefix_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextNodeNamePrefix, ""); + LOGS_DEFAULT(VERBOSE) << "User specified QNN context node name prefix: " << context_node_name_prefix_; } static const std::string BACKEND_PATH = "backend_path"; @@ -613,7 +620,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto gen_metadef_name = [&]() { uint64_t model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(QNN, "_", model_hash, "_", metadef_id); + return MakeString(QNN, context_node_name_prefix_, "_", model_hash, "_", metadef_id); }; // For model with EPContext, make sure each partition only has one single EPContext node diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e7419dabb1..f00ffb6cfd 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -80,6 +80,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; + std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a3768cb98f..be3bd2cc5d 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -279,6 +279,45 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + std::string node_name_prefix = "node_name_prefix_test"; + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + for (auto& node : model->MainGraph().Nodes()) { + if (node.OpType() == "EPContext") { + EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); + } + } + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + // Run QDQ model on HTP 3 times // 1st run will generate the Qnn context cache onnx file // 2nd run directly loads and run from Qnn context cache model