diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index b157396306..fd9bf200c4 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -88,9 +88,33 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, qnn_model); } - std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); - std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name); + std::string external_qnn_ctx_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); + ORT_RETURN_IF(external_qnn_ctx_binary_file_name.empty(), "The file path in ep_cache_context should not be empty."); +#ifdef _WIN32 + onnxruntime::PathString external_qnn_context_binary_path = onnxruntime::ToPathString(external_qnn_ctx_binary_file_name); + auto ctx_file_path = std::filesystem::path(external_qnn_context_binary_path.c_str()); + ORT_RETURN_IF(ctx_file_path.is_absolute(), "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ", + external_qnn_ctx_binary_file_name); + auto relative_path = ctx_file_path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory."); + } + + std::filesystem::path context_binary_path = folder_path.append(relative_path); +#else + ORT_RETURN_IF(external_qnn_ctx_binary_file_name[0] == '/', + "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ", + external_qnn_ctx_binary_file_name); + if (external_qnn_ctx_binary_file_name.find("..", 0) != std::string::npos) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory."); + } + std::filesystem::path context_binary_path = folder_path.append(external_qnn_ctx_binary_file_name); + std::string file_full_path = context_binary_path.string(); +#endif + if (!std::filesystem::is_regular_file(context_binary_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible."); + } size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index c4244fe532..4ac1f5ddca 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -908,6 +908,135 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + // The .. in the path will cause INVALID_GRAPH + ep_context_node.AddAttribute("ep_cache_context", external_bin_path); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNN"); + helper.SetGraphOutputs(); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + return model_data; +} + +// Create a model with EPContext node. Set the node property ep_cache_context has ".." +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context has absolute path +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { +#if defined(_WIN32) + std::string external_ctx_bin_path = "D:/qnn_context.bin"; +#else + std::string external_ctx_bin_path = "/data/qnn_context.bin"; +#endif + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context to empty string +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + // Run QDQ model on HTP with 2 inputs // 1st run will generate the Qnn context cache onnx file // 2nd run will load and run from QDQ model + Qnn context cache model