mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Validate the context_file_path before EP compile graphs (#23611)
Validate the context_file_path before the EPs compile the graphs so that an invalid setting fails fast. This avoids the possibility of an EP generating new files (a context binary file or blob file) that overwrite existing ones. Return an error if the path points to a folder.
This commit is contained in:
parent
0887e3694a
commit
002916acb0
3 changed files with 138 additions and 6 deletions
|
|
@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
|
|||
|
||||
// Specify the file path for the Onnx model which has EP context.
|
||||
// Default to original_file_name_ctx.onnx if not specified
|
||||
// Folder is not a valid option
|
||||
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
|
||||
|
||||
// Flag to specify whether to dump the EP context into the Onnx model.
|
||||
|
|
|
|||
|
|
@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validate the ep_context_path to make sure it is file path and check whether the file exist already
|
||||
static Status EpContextFilePathCheck(const std::string& ep_context_path,
|
||||
const std::filesystem::path& model_path) {
|
||||
std::filesystem::path context_cache_path;
|
||||
if (!ep_context_path.empty()) {
|
||||
context_cache_path = ep_context_path;
|
||||
if (!context_cache_path.has_filename()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
|
||||
}
|
||||
} else if (!model_path.empty()) {
|
||||
context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
|
||||
}
|
||||
|
||||
if (std::filesystem::exists(context_cache_path)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
|
||||
context_cache_path, "' exist already.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
|
||||
const Graph& graph,
|
||||
const std::filesystem::path& ep_context_path,
|
||||
|
|
@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty");
|
||||
}
|
||||
|
||||
if (std::filesystem::exists(context_cache_path)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
|
||||
context_cache_path, "' exist already.");
|
||||
}
|
||||
|
||||
Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
|
||||
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
|
||||
IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()},
|
||||
|
|
@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
|||
|
||||
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
|
||||
if (ep_context_enabled) {
|
||||
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
// Check before EP compile graphs
|
||||
ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
|
||||
|
||||
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
|
||||
if (ep_context_enabled) {
|
||||
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
|
||||
|
|
|
|||
|
|
@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) {
|
|||
EpCtxCpuNodeWithExternalIniFileTestBody(false);
|
||||
}
|
||||
|
||||
// Set ep.context_file_path to folder path which is not a valid option, check the error message
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
provider_options["offload_graph_io_quantization"] = "0";
|
||||
|
||||
const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
|
||||
|
||||
auto& logging_manager = DefaultLoggingManager();
|
||||
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
|
||||
|
||||
onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
|
||||
logging_manager.DefaultLogger());
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
bool single_ep_node = true;
|
||||
BuildGraphWithQAndNonQ(single_ep_node)(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ASSERT_STATUS_OK(model.MainGraph().Resolve());
|
||||
|
||||
// Serialize the model to a string.
|
||||
std::string model_data;
|
||||
model.ToProto().SerializeToString(&model_data);
|
||||
|
||||
const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
|
||||
|
||||
const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/";
|
||||
std::remove(ep_context_onnx_file.c_str());
|
||||
Ort::SessionOptions so;
|
||||
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
|
||||
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
|
||||
so.AppendExecutionProvider("QNN", provider_options);
|
||||
|
||||
try {
|
||||
Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
|
||||
FAIL(); // Should not get here!
|
||||
} catch (const Ort::Exception& excpt) {
|
||||
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
|
||||
ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder."));
|
||||
}
|
||||
}
|
||||
|
||||
// Create session 1 to generate context binary file
|
||||
// Create session 2 to do same thing, make sure session 2 failed because file exist already
|
||||
// Make sure no new file over write from session 2
|
||||
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
provider_options["offload_graph_io_quantization"] = "0";
|
||||
|
||||
const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
|
||||
|
||||
auto& logging_manager = DefaultLoggingManager();
|
||||
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
|
||||
|
||||
onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
|
||||
logging_manager.DefaultLogger());
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
bool single_ep_node = true;
|
||||
BuildGraphWithQAndNonQ(single_ep_node)(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ASSERT_STATUS_OK(model.MainGraph().Resolve());
|
||||
|
||||
// Serialize the model to a string.
|
||||
std::string model_data;
|
||||
model.ToProto().SerializeToString(&model_data);
|
||||
|
||||
const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
|
||||
|
||||
const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx";
|
||||
const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin";
|
||||
|
||||
std::remove(ep_context_onnx_file.c_str());
|
||||
Ort::SessionOptions so;
|
||||
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
|
||||
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
|
||||
so.AppendExecutionProvider("QNN", provider_options);
|
||||
|
||||
Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so);
|
||||
|
||||
auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file);
|
||||
|
||||
try {
|
||||
Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so);
|
||||
FAIL(); // Should not get here!
|
||||
} catch (const Ort::Exception& excpt) {
|
||||
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL);
|
||||
ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already."));
|
||||
auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file);
|
||||
ASSERT_EQ(modify_time_1, modify_time_2);
|
||||
}
|
||||
|
||||
ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0);
|
||||
ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0);
|
||||
}
|
||||
|
||||
// Create a model with Cast + Add (quantized)
|
||||
// cast_input -> Cast -> Q -> DQ \
|
||||
// Add -> Q -> DQ -> output
|
||||
|
|
|
|||
Loading…
Reference in a new issue