diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index efe628b354..018e1ddc81 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; // Specify the file path for the Onnx model which has EP context. // Default to original_file_name_ctx.onnx if not specified +// Folder is not a valid option static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; // Flag to specify whether to dump the EP context into the Onnx model. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 7c980a1aeb..c02e6cf3af 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +// Validate the ep_context_path to make sure it is file path and check whether the file exist already +static Status EpContextFilePathCheck(const std::string& ep_context_path, + const std::filesystem::path& model_path) { + std::filesystem::path context_cache_path; + if (!ep_context_path.empty()) { + context_cache_path = ep_context_path; + if (!context_cache_path.has_filename()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); + } + } else if (!model_path.empty()) { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); + } + + if (std::filesystem::exists(context_cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", + context_cache_path, "' exist already."); + } + + return Status::OK(); +} + static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, const std::filesystem::path& ep_context_path, @@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); } - if (std::filesystem::exists(context_cache_path)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already."); - } - Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()}, @@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + if (ep_context_enabled) { + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + // Check before EP compile graphs + ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); + } + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); - bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8670fbbe2b..29a916997b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(false); } +// Set ep.context_file_path to folder path which is not a valid option, check the error message +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/"; + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } +} + +// Create session 1 to generate context binary file +// Create session 2 to do same thing, make sure session 2 failed because file exist already +// Make sure no new file over write from session 2 +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; + const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; + + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so); + + auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file); + + try { + Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL); + ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already.")); + auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file); + ASSERT_EQ(modify_time_1, modify_time_2); + } + + ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0); + ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0); +} + // Create a model with Case + Add (quantized) // cast_input -> Cast -> Q -> DQ \ // Add -> Q -> DQ -> output