Validate the context_file_path before EP compile graphs (#23611)

Validate the context_file_path before the EP compiles graphs so that an invalid setting fails fast. This avoids the possibility that the EP generates a new file (context binary file or blob file) that overwrites an existing file. Return an error if the path points to a folder.
This commit is contained in:
Hector Li 2025-02-07 21:31:11 -08:00 committed by GitHub
parent 0887e3694a
commit 002916acb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 138 additions and 6 deletions

View file

@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
// Specify the file path for the Onnx model which has EP context.
// Defaults to original_file_name_ctx.onnx if not specified.
// A folder path is not a valid value.
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
// Flag to specify whether to dump the EP context into the Onnx model.

View file

@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
return Status::OK();
}
// Validates the output path of the EP context model and makes sure nothing already exists there.
//
// ep_context_path: user supplied output path (may be empty).
// model_path: path of the source model; when ep_context_path is empty the output path is
//             derived from it by appending "_ctx.onnx".
//
// Returns:
//   INVALID_ARGUMENT - ep_context_path has no filename component or names an existing directory,
//                      or both ep_context_path and model_path are empty.
//   FAIL             - the resolved output path already exists.
//   OK               - otherwise.
static Status EpContextFilePathCheck(const std::string& ep_context_path,
                                     const std::filesystem::path& model_path) {
  std::filesystem::path context_cache_path;
  if (!ep_context_path.empty()) {
    context_cache_path = ep_context_path;
    // has_filename() only rejects paths that end with a separator. A directory given without
    // the trailing separator (e.g. "./out_dir") has to be detected on disk, otherwise it would
    // be reported below as an already existing file.
    if (!context_cache_path.has_filename() || std::filesystem::is_directory(context_cache_path)) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
    }
  } else if (!model_path.empty()) {
    context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
  }

  // Refuse to continue if the target is already present so that no existing file gets overwritten.
  if (std::filesystem::exists(context_cache_path)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
                           context_cache_path, "' exist already.");
  }

  return Status::OK();
}
static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
const Graph& graph,
const std::filesystem::path& ep_context_path,
@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty");
}
if (std::filesystem::exists(context_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
context_cache_path, "' exist already.");
}
Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()},
@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
// Check before EP compile graphs
ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
}
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");

View file

@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) {
EpCtxCpuNodeWithExternalIniFileTestBody(false);
}
// ep.context_file_path must name a file. Point it at a folder and verify that session creation
// is rejected with ORT_INVALID_ARGUMENT and the expected error message.
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) {
  ProviderOptions qnn_options;
#if defined(_WIN32)
  qnn_options["backend_path"] = "QnnHtp.dll";
#else
  qnn_options["backend_path"] = "libQnnHtp.so";
#endif
  qnn_options["offload_graph_io_quantization"] = "0";

  // Build a small in-memory model that mixes quantized and non-quantized nodes.
  const std::unordered_map<std::string, int> opset_versions = {{"", 13}, {kMSDomain, 1}};
  auto& log_mgr = DefaultLoggingManager();
  log_mgr.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), opset_versions, {},
                           log_mgr.DefaultLogger());
  Graph& main_graph = model.MainGraph();
  ModelTestBuilder builder(main_graph);
  bool use_single_ep_node = true;
  BuildGraphWithQAndNonQ(use_single_ep_node)(builder);
  builder.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());

  // The session is created from the serialized bytes of the model.
  std::string serialized_model;
  model.ToProto().SerializeToString(&serialized_model);
  const auto model_bytes = AsByteSpan(serialized_model.data(), serialized_model.size());

  // Trailing separator: this is a folder path, not a file path.
  const std::string context_folder_path = "./ep_context_folder_not_expected/";
  std::remove(context_folder_path.c_str());

  Ort::SessionOptions session_options;
  session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_folder_path.c_str());
  session_options.AppendExecutionProvider("QNN", qnn_options);

  try {
    Ort::Session session(*ort_env, model_bytes.data(), model_bytes.size(), session_options);
    FAIL();  // Session creation is expected to throw.
  } catch (const Ort::Exception& ex) {
    ASSERT_EQ(ex.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
    ASSERT_THAT(ex.what(), testing::HasSubstr("context_file_path should not point to a folder."));
  }
}
// Session 1 generates the EP context model and the context binary file.
// Session 2 repeats the same request and must fail because the context model exists already.
// The context binary written by session 1 must not be overwritten by session 2, which is
// verified by comparing its last write time before and after session 2.
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) {
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  provider_options["offload_graph_io_quantization"] = "0";

  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
  auto& logging_manager = DefaultLoggingManager();
  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
                           logging_manager.DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  bool single_ep_node = true;
  BuildGraphWithQAndNonQ(single_ep_node)(helper);
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());

  // Serialize the model to a string.
  std::string model_data;
  model.ToProto().SerializeToString(&model_data);
  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());

  const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx";
  const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin";
  // Remove leftovers of a previous run. Both artifacts are cleaned because a fatal assertion
  // below returns before the cleanup at the end of the test, which can leave either file behind.
  std::remove(ep_context_onnx_file.c_str());
  std::remove(ep_context_binary_file.c_str());

  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
  so.AppendExecutionProvider("QNN", provider_options);

  // Session 1: generates the context model and the context binary.
  Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so);

  auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file);

  // Session 2: same options, the context model exists now so creation has to fail.
  try {
    Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so);
    FAIL();  // Should not get here!
  } catch (const Ort::Exception& excpt) {
    ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL);
    ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already."));
    // An unchanged write time shows that session 2 did not regenerate the binary.
    auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file);
    ASSERT_EQ(modify_time_1, modify_time_2);
  }

  // Clean up; both files are expected to exist at this point.
  ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0);
  ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0);
}
// Create a model with Case + Add (quantized)
// cast_input -> Cast -> Q -> DQ \
// Add -> Q -> DQ -> output