mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Check the ep_cache_context and don't allow access outside the directory (#19174)
### Description Check the ep_cache_context node property for EPContext node, and don't allow relative path like "../file_path"
This commit is contained in:
parent
9da3e36138
commit
dadd3ea704
2 changed files with 155 additions and 2 deletions
|
|
@ -88,9 +88,33 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
|||
qnn_model);
|
||||
}
|
||||
|
||||
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
|
||||
std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name);
|
||||
std::string external_qnn_ctx_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
ORT_RETURN_IF(external_qnn_ctx_binary_file_name.empty(), "The file path in ep_cache_context should not be empty.");
|
||||
#ifdef _WIN32
|
||||
onnxruntime::PathString external_qnn_context_binary_path = onnxruntime::ToPathString(external_qnn_ctx_binary_file_name);
|
||||
auto ctx_file_path = std::filesystem::path(external_qnn_context_binary_path.c_str());
|
||||
ORT_RETURN_IF(ctx_file_path.is_absolute(), "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ",
|
||||
external_qnn_ctx_binary_file_name);
|
||||
auto relative_path = ctx_file_path.lexically_normal().make_preferred().wstring();
|
||||
if (relative_path.find(L"..", 0) != std::string::npos) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory.");
|
||||
}
|
||||
|
||||
std::filesystem::path context_binary_path = folder_path.append(relative_path);
|
||||
#else
|
||||
ORT_RETURN_IF(external_qnn_ctx_binary_file_name[0] == '/',
|
||||
"External mode should set ep_cache_context field with a relative path, but it is an absolute path: ",
|
||||
external_qnn_ctx_binary_file_name);
|
||||
if (external_qnn_ctx_binary_file_name.find("..", 0) != std::string::npos) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory.");
|
||||
}
|
||||
std::filesystem::path context_binary_path = folder_path.append(external_qnn_ctx_binary_file_name);
|
||||
std::string file_full_path = context_binary_path.string();
|
||||
#endif
|
||||
if (!std::filesystem::is_regular_file(context_binary_path)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible.");
|
||||
}
|
||||
|
||||
size_t buffer_size{0};
|
||||
std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary);
|
||||
|
|
|
|||
|
|
@ -908,6 +908,135 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) {
|
|||
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
|
||||
}
|
||||
|
||||
std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
|
||||
const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
|
||||
auto& logging_manager = DefaultLoggingManager();
|
||||
onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
|
||||
logging_manager.DefaultLogger());
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
std::vector<int64_t> shape = {2, 3};
|
||||
NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
|
||||
auto* graph_output = helper.MakeOutput<float>(shape);
|
||||
Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
|
||||
ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
|
||||
// The .. in the path will cause INVALID_GRAPH
|
||||
ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
|
||||
ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
|
||||
ep_context_node.AddAttribute("source", "QNN");
|
||||
helper.SetGraphOutputs();
|
||||
std::string model_data;
|
||||
model.ToProto().SerializeToString(&model_data);
|
||||
|
||||
return model_data;
|
||||
}
|
||||
|
||||
// Create a model with EPContext node. Set the node property ep_cache_context has ".."
|
||||
// Verify that it return INVALID_GRAPH status
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) {
|
||||
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "qnn_ctx_model_logger";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
|
||||
// Verify the return status with code INVALID_GRAPH
|
||||
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
|
||||
}
|
||||
|
||||
// Create a model with EPContext node. Set the node property ep_cache_context has absolute path
|
||||
// Verify that it return INVALID_GRAPH status
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) {
|
||||
#if defined(_WIN32)
|
||||
std::string external_ctx_bin_path = "D:/qnn_context.bin";
|
||||
#else
|
||||
std::string external_ctx_bin_path = "/data/qnn_context.bin";
|
||||
#endif
|
||||
std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path);
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "qnn_ctx_model_logger";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
|
||||
// Verify the return status with code INVALID_GRAPH
|
||||
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
|
||||
}
|
||||
|
||||
// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist
|
||||
// Verify that it return INVALID_GRAPH status
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) {
|
||||
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin");
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "qnn_ctx_model_logger";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
|
||||
// Verify the return status with code INVALID_GRAPH
|
||||
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
|
||||
}
|
||||
|
||||
// Create a model with EPContext node. Set the node property ep_cache_context to empty string
|
||||
// Verify that it return INVALID_GRAPH status
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) {
|
||||
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("");
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "qnn_ctx_model_logger";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
|
||||
// Verify the return status with code INVALID_GRAPH
|
||||
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
|
||||
}
|
||||
|
||||
// Run QDQ model on HTP with 2 inputs
|
||||
// 1st run will generate the Qnn context cache onnx file
|
||||
// 2nd run will load and run from QDQ model + Qnn context cache model
|
||||
|
|
|
|||
Loading…
Reference in a new issue