[QNN EP] Add model description into context binary file metadata for validation (#16248)

### Description
Add model description into context binary file metadata for validation

### Motivation and Context
Dump more information for validation

---------

Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
This commit is contained in:
Hector Li 2023-06-08 22:13:43 -07:00 committed by GitHub
parent d1e8d4a261
commit a9d47f72a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 88 additions and 61 deletions

View file

@ -316,6 +316,26 @@ Status QnnBackendManager::ReleaseContext() {
return Status::OK();
}
bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring) {
// Avoid duplicate work
if (!context_cache_path_.empty()) {
return ctx_file_exists_;
}
model_description_ = model_description;
// Use user provided context cache file path if exist, otherwise try model_file.onnx.bin by default
if (customer_context_cache_path.empty()) {
context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin";
} else {
context_cache_path_ = customer_context_cache_path;
}
ctx_file_exists_ = std::filesystem::exists(context_cache_path_);
return ctx_file_exists_;
}
Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
const std::vector<uint16_t> data{value};
std::vector<unsigned char> data_bytes(sizeof(uint16_t) / sizeof(unsigned char));
@ -324,9 +344,7 @@ Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
return Status::OK();
}
Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring,
const std::string& model_name,
const std::string& graph_name) {
Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) {
if (nullptr == qnn_interface_.contextGetBinarySize ||
nullptr == qnn_interface_.contextGetBinary) {
LOGS(*logger_, ERROR) << "Failed to get valid function pointer.";
@ -362,7 +380,7 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context written buffer exceeds allocated buffer size.");
}
std::ofstream of_stream(context_cache_pathstring.c_str(), std::ofstream::binary);
std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary);
if (!of_stream) {
LOGS(*logger_, ERROR) << "Failed to open cached context file.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
@ -371,7 +389,10 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
// Write Ort metadata into context binary file
uint16_t model_name_length = static_cast<uint16_t>(model_name.length());
uint16_t graph_name_length = static_cast<uint16_t>(graph_name.length());
uint16_t header_length = 3 * sizeof(uint16_t) + model_name_length + graph_name_length;
uint16_t model_description_length = static_cast<uint16_t>(model_description_.length());
// Header: uint16_t(totale_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description
uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length;
uint16_t totale_length = header_length + static_cast<uint16_t>(strlen(QNN_PROVIDER));
of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER));
@ -382,6 +403,11 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length));
of_stream.write(graph_name.c_str(), graph_name_length);
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length));
of_stream.write(model_description_.c_str(), model_description_length);
model_description_.clear();
LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << totale_length;
of_stream.write(reinterpret_cast<char*>(context_buffer.get()), written_buffer_size);
@ -390,14 +416,16 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
return Status::OK();
}
Status QnnBackendManager::LoadCachedQnnContext(const onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model) {
Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
ORT_RETURN_IF(result, "Failed to get valid function pointer.");
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
uint64_t buffer_size{0};
std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary);
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
@ -466,6 +494,8 @@ Status QnnBackendManager::LoadCachedQnnContext(const onnxruntime::PathString& co
ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo());
context_created_ = true;
model_description_.clear();
model_description_from_ctx_cache_.clear();
LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed.";
return Status::OK();
}
@ -502,12 +532,11 @@ Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) {
}
/* \brief: Try to get metadata from Ort generated context cache binary file.
* \param[in] context_cache_pathstring - context cache binary file path string
* Cached context binary file generated by Ort has some metadata which can be used for validation with the model
* to avoid user choose a wrong context binary file which is not for this model
* It is treated as Qnn generated context binary file if no metadata found from the file
*/
Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathString& context_cache_pathstring) {
Status QnnBackendManager::GetMetadataFromOrtContextFile() {
// Only try parse meta data once
if (ctx_metadata_tried_) {
return Status::OK();
@ -515,7 +544,7 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS
ctx_metadata_tried_ = true;
uint64_t buffer_size = 0;
std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary);
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
@ -533,17 +562,18 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS
}
ort_generated_ctx_cache_ = true;
uint16_t header_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, header_length));
ort_ctx_metadata_length_ = header_length + static_cast<uint16_t>(ort_flag_length);
uint16_t str_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ort_ctx_metadata_length_ = str_length + static_cast<uint16_t>(ort_flag_length);
uint16_t model_name_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, model_name_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(model_name_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(str_length)));
uint16_t graph_name_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, graph_name_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(graph_name_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(str_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast<size_t>(str_length)));
return Status::OK();
}
@ -555,15 +585,30 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS
* so only validate the graph name for 2nd call
*/
Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) {
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
// Get metadata from cached context binary file
ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile());
// The context binary file doesn't have ORT metadata, so it is generated from QNN toolchain not from ORT
if (!ort_generated_ctx_cache_) {
return Status::OK();
}
ORT_RETURN_IF(model_name != model_name_from_ctx_cache_,
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ + " is different with target: " + model_name);
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ +
" is different with target: " + model_name +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_,
"Model description from context cache metadata: " + model_description_from_ctx_cache_ +
" is different with target: " + model_description_ +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_,
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ + " is different with target: " + graph_name);
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ +
" is different with target: " + graph_name +
". You may need to re-generate the context binary file.");
get_capability_round_2_ = true;
return Status::OK();

View file

@ -69,13 +69,11 @@ class QnnBackendManager {
return CreateContext();
}
Status DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring,
const std::string& model_name,
const std::string& graph_name);
Status DumpQnnContext(const std::string& model_name, const std::string& graph_name);
Status LoadCachedQnnContext(const onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model);
Status LoadCachedQnnContext(QnnModel& qnn_model);
Status GetMetadataFromOrtContextFile(const onnxruntime::PathString& model_path);
Status GetMetadataFromOrtContextFile();
Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name);
@ -133,6 +131,10 @@ class QnnBackendManager {
// NPU backend requires quantized model
bool IsNpuBackend() { return is_npu_backend_; }
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring);
private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
@ -197,6 +199,10 @@ class QnnBackendManager {
HtpPerformanceMode htp_performance_mode_;
std::string model_name_from_ctx_cache_ = "";
std::string graph_name_from_ctx_cache_ = "";
std::string model_description_from_ctx_cache_ = "";
std::string model_description_ = "";
std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool ctx_metadata_tried_ = false;
bool ort_generated_ctx_cache_ = false;
bool get_capability_round_2_ = false;

View file

@ -276,18 +276,9 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const auto& logger = *GetLogger();
bool load_from_cached_context = false;
if (context_cache_enabled_) {
onnxruntime::PathString context_cache_pathstring;
load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(),
context_cache_pathstring);
// Get metadata from cached context binary file
if (load_from_cached_context) {
auto rt = qnn_backend_manager_->GetMetadataFromOrtContextFile(context_cache_pathstring);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "Failed to get metadata from cached context binary file. " << rt.ErrorMessage();
return result;
}
}
load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
graph_viewer.Description(),
graph_viewer.ModelPath().ToPathString());
}
// Load from cached context will load the QnnSystem lib and skip the Qnn context creation
@ -444,19 +435,6 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
return Status::OK();
}
bool QNNExecutionProvider::IsContextCacheFileExists(const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_pathstring) const {
// Use user provided context cache file path if exist, otherwise try model_file.onnx.bin by default
if (context_cache_path_.empty()) {
context_cache_pathstring = model_pathstring + ToPathString(".bin");
} else {
context_cache_pathstring = ToPathString(context_cache_path_);
}
bool context_cache_file_exist = std::filesystem::exists(context_cache_pathstring.c_str());
return context_cache_file_exist;
}
Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
const auto& logger = *GetLogger();
@ -466,15 +444,18 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support singel partition for context cache feature.");
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
onnxruntime::PathString context_cache_pathstring;
bool load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(),
context_cache_pathstring);
// The dumy_model_description won't be used since IsContextCacheFileExists call cached the result
// The graph_viewer.Description here is not same with original model
std::string dumy_model_description = "";
bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
dumy_model_description,
graph_viewer.ModelPath().ToPathString());
// Load and execute from cached context if exist
if (load_from_cached_context) {
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
qnn_backend_manager_.get(),
is_npu_backend);
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(context_cache_pathstring, *(qnn_model.get())));
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get())));
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
@ -490,8 +471,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
// graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id]
// dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(context_cache_pathstring,
GetFileNameFromModelPath(graph_viewer.ModelPath()),
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()),
graph_viewer.Name()));
}
return Status::OK();

View file

@ -54,9 +54,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::vector<NodeComputeInfo>& node_compute_funcs,
const logging::Logger& logger);
bool IsContextCacheFileExists(const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_pathstring) const;
void ParseHtpPerformanceMode(std::string htp_performance_mode_string);
private:

View file

@ -156,7 +156,6 @@ TEST_F(QnnHTPBackendTests, TestQDQAtanTest) {
// 1st run will generate the Qnn context cache binary file
// 2nd run will load and run from Qnn context cache binary file
TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
RunQDQSingleInputOpTest({1, 2, 3}, "Atan", "TestQDQGeluTest", 11, ExpectedEPNodeAssignment::All, 1);
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";