mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
[QNN EP] Reduce overhead of QNN context binary loading (#17965)
### Description Reduce overhead of QNN context binary loading by avoiding memory copy ### Motivation and Context Reduce the session initialization time and memory usage while load from QNN context binary
This commit is contained in:
parent
cbb0e0f83c
commit
35ecce4549
7 changed files with 74 additions and 80 deletions
|
|
@ -60,26 +60,32 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context,
|
||||
const logging::Logger& logger) {
|
||||
Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModel& qnn_model,
|
||||
const logging::Logger& logger) {
|
||||
using namespace onnxruntime;
|
||||
std::shared_ptr<Model> model;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
|
||||
const auto& graph = model->MainGraph();
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context));
|
||||
|
||||
return Status::OK();
|
||||
return GetEpContextFromGraph(GraphViewer(graph),
|
||||
ctx_onnx_model_path,
|
||||
qnn_backend_manager,
|
||||
qnn_model);
|
||||
}
|
||||
|
||||
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context) {
|
||||
Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModel& qnn_model) {
|
||||
const auto& node = graph_viewer.Nodes().begin();
|
||||
NodeAttrHelper node_helper(*node);
|
||||
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
|
||||
if (is_embed_mode) {
|
||||
ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
|
||||
static_cast<uint64_t>(context_binary.length()),
|
||||
qnn_model);
|
||||
} else {
|
||||
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
|
||||
|
|
@ -88,17 +94,23 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
|||
size_t buffer_size{0};
|
||||
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
|
||||
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
|
||||
|
||||
cache_file.seekg(0, cache_file.end);
|
||||
buffer_size = static_cast<size_t>(cache_file.tellg());
|
||||
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
|
||||
|
||||
cache_file.seekg(0, cache_file.beg);
|
||||
ep_cache_context.reserve(buffer_size);
|
||||
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
|
||||
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
|
||||
// Load file into buffer
|
||||
ep_cache_context.assign(std::istreambuf_iterator<char>(cache_file), std::istreambuf_iterator<char>());
|
||||
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
|
||||
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
|
||||
cache_file.close();
|
||||
ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file.");
|
||||
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
|
||||
static_cast<uint64_t>(buffer_size),
|
||||
qnn_model);
|
||||
}
|
||||
ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ namespace onnxruntime {
|
|||
namespace qnn {
|
||||
|
||||
class QnnModel;
|
||||
class QnnBackendManager;
|
||||
|
||||
static const std::string EPCONTEXT_OP = "EPContext";
|
||||
static const std::string MAIN_CONTEXT = "main_context";
|
||||
|
|
@ -37,32 +38,24 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
|
|||
std::vector<NodeArg*>& node_args,
|
||||
onnxruntime::Graph& graph);
|
||||
|
||||
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_engine_cache,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context);
|
||||
|
||||
class QnnCacheModelHandler {
|
||||
public:
|
||||
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
|
||||
}
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);
|
||||
|
||||
Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
bool is_qnn_ctx_model,
|
||||
bool is_ctx_cache_file_exist,
|
||||
std::string& ep_engine_cache,
|
||||
const logging::Logger& logger) const {
|
||||
Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
bool is_qnn_ctx_model,
|
||||
bool is_ctx_cache_file_exist,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModel& qnn_model,
|
||||
const logging::Logger& logger) {
|
||||
if (is_qnn_ctx_model) {
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache));
|
||||
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
|
||||
} else if (is_ctx_cache_file_exist) {
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger));
|
||||
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -92,12 +85,25 @@ class QnnCacheModelHandler {
|
|||
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
|
||||
const logging::Logger& logger);
|
||||
|
||||
private:
|
||||
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModel& qnn_model,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
QnnBackendManager* qnn_backend_manager,
|
||||
QnnModel& qnn_model);
|
||||
|
||||
private:
|
||||
bool is_metadata_ready_ = false;
|
||||
// model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model
|
||||
std::string model_name_ = "";
|
||||
std::string model_description_ = "";
|
||||
std::string graph_partition_name_ = "";
|
||||
std::string cache_source_ = "";
|
||||
|
||||
std::string context_cache_path_ = "";
|
||||
bool ctx_file_exists_ = false;
|
||||
bool get_capability_round_2_ = false;
|
||||
|
|
|
|||
|
|
@ -448,20 +448,7 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
|
|||
return context_buffer;
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
|
||||
QnnModel& qnn_model,
|
||||
bool& loaded_from_cache) {
|
||||
loaded_from_cache = false;
|
||||
|
||||
if (!ep_engine_cache.empty()) {
|
||||
ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model));
|
||||
loaded_from_cache = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) {
|
||||
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) {
|
||||
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
|
||||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
|
||||
nullptr == qnn_sys_interface_.systemContextFree;
|
||||
|
|
@ -474,8 +461,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
|
|||
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
|
||||
Qnn_ContextBinarySize_t binary_info_size{0};
|
||||
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
|
||||
static_cast<void*>(const_cast<char*>(buffer.c_str())),
|
||||
static_cast<uint64_t>(buffer.length()),
|
||||
static_cast<void*>(buffer),
|
||||
buffer_length,
|
||||
&binary_info,
|
||||
&binary_info_size);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
|
||||
|
|
@ -502,8 +489,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
|
|||
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
|
||||
device_handle_,
|
||||
(const QnnContext_Config_t**)&context_config_,
|
||||
static_cast<void*>(const_cast<char*>(buffer.c_str())),
|
||||
static_cast<uint64_t>(buffer.length()),
|
||||
static_cast<void*>(buffer),
|
||||
buffer_length,
|
||||
&context_,
|
||||
profile_backend_handle_);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
|
||||
|
|
|
|||
|
|
@ -74,9 +74,7 @@ class QnnBackendManager {
|
|||
|
||||
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
|
||||
|
||||
Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
|
||||
QnnModel& qnn_model,
|
||||
bool& loaded_from_cache);
|
||||
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model);
|
||||
|
||||
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
|
||||
|
||||
|
|
@ -174,8 +172,6 @@ class QnnBackendManager {
|
|||
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
|
||||
}
|
||||
|
||||
Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model);
|
||||
|
||||
private:
|
||||
const std::string backend_path_;
|
||||
const logging::Logger* logger_ = nullptr;
|
||||
|
|
|
|||
|
|
@ -483,33 +483,28 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
bool is_qnn_ctx_model = false;
|
||||
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
|
||||
|
||||
if (context_cache_enabled_ || is_qnn_ctx_model) {
|
||||
bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists();
|
||||
if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
|
||||
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
|
||||
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
|
||||
bool loaded_from_cache = false;
|
||||
std::string ep_engine_cache;
|
||||
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer,
|
||||
context_cache_path_,
|
||||
is_qnn_ctx_model,
|
||||
qnn_cache_model_handler_->GetIsContextCacheFileExists(),
|
||||
ep_engine_cache,
|
||||
logger));
|
||||
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache,
|
||||
*(qnn_model.get()),
|
||||
loaded_from_cache));
|
||||
// Load and execute from cached context if exist
|
||||
if (loaded_from_cache) {
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
|
||||
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer,
|
||||
context_cache_path_,
|
||||
is_qnn_ctx_model,
|
||||
is_ctx_file_exist,
|
||||
qnn_backend_manager_.get(),
|
||||
*(qnn_model.get()),
|
||||
logger));
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
|
||||
|
||||
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
|
||||
// the name here should be same with context->node_name in compute_info
|
||||
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
|
||||
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
|
||||
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
|
||||
// the name here should be same with context->node_name in compute_info
|
||||
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
|
||||
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
|
||||
|
||||
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
|
||||
return Status::OK();
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
|
||||
|
|
@ -524,8 +519,6 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
qnn_models_,
|
||||
logger));
|
||||
}
|
||||
qnn_cache_model_handler_.reset();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const {
|
|||
return node_attributes_.at(key).i();
|
||||
}
|
||||
|
||||
std::string NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
|
||||
const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
|
||||
if (!HasAttr(key))
|
||||
return def_val;
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class NodeAttrHelper {
|
|||
|
||||
int64_t Get(const std::string& key, int64_t def_val) const;
|
||||
|
||||
std::string Get(const std::string& key, const std::string& def_val) const;
|
||||
const std::string& Get(const std::string& key, const std::string& def_val) const;
|
||||
|
||||
std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;
|
||||
std::vector<float> Get(const std::string& key, const std::vector<float>& def_val) const;
|
||||
|
|
|
|||
Loading…
Reference in a new issue