From 002916acb05f0d7b367b5358ebe6af155595ecef Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 7 Feb 2025 21:31:11 -0800 Subject: [PATCH] Validate the context_file_path before EP compile graphs (#23611) Validate the context_file_path before EP compile graphs to make it fail fast. To avoid the possibility that EP generate new file (context binary file or blob file) over write the existing file. Return error if the path points to folder. --- .../onnxruntime_session_options_config_keys.h | 1 + .../core/framework/graph_partitioner.cc | 36 +++++- .../test/providers/qnn/qnn_ep_context_test.cc | 107 ++++++++++++++++++ 3 files changed, 138 insertions(+), 6 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index efe628b354..018e1ddc81 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; // Specify the file path for the Onnx model which has EP context. // Default to original_file_name_ctx.onnx if not specified +// Folder is not a valid option static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; // Flag to specify whether to dump the EP context into the Onnx model. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 7c980a1aeb..c02e6cf3af 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +// Validate the ep_context_path to make sure it is file path and check whether the file exist already +static Status EpContextFilePathCheck(const std::string& ep_context_path, + const std::filesystem::path& model_path) { + std::filesystem::path context_cache_path; + if (!ep_context_path.empty()) { + context_cache_path = ep_context_path; + if (!context_cache_path.has_filename()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); + } + } else if (!model_path.empty()) { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); + } + + if (std::filesystem::exists(context_cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", + context_cache_path, "' exist already."); + } + + return Status::OK(); +} + static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, const std::filesystem::path& ep_context_path, @@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); } - if (std::filesystem::exists(context_cache_path)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already."); - } - Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()}, @@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + if (ep_context_enabled) { + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + // Check before EP compile graphs + ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); + } + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); - bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8670fbbe2b..29a916997b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(false); } +// Set ep.context_file_path to folder path which is not a valid option, check the error message +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/"; + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } +} + +// Create session 1 to generate context binary file +// Create session 2 to do same thing, make sure session 2 failed because file exist already +// Make sure no new file over write from session 2 +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; + const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; + + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so); + + auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file); + + try { + Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL); + ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already.")); + auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file); + ASSERT_EQ(modify_time_1, modify_time_2); + } + + ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0); + ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0); +} + // Create a model with Case + Add (quantized) // cast_input -> Cast -> Q -> DQ \ // Add -> Q -> DQ -> output