[QNN EP] Reduce overhead of QNN context binary loading (#17965)

### Description
Reduce overhead of QNN context binary loading by avoiding memory copy

### Motivation and Context
Reduce the session initialization time and memory usage while load from
QNN context binary
This commit is contained in:
Hector Li 2023-10-18 15:30:35 -07:00 committed by GitHub
parent cbb0e0f83c
commit 35ecce4549
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 74 additions and 80 deletions

View file

@ -60,26 +60,32 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
return Status::OK();
}
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
std::string& ep_cache_context,
const logging::Logger& logger) {
Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context));
return Status::OK();
return GetEpContextFromGraph(GraphViewer(graph),
ctx_onnx_model_path,
qnn_backend_manager,
qnn_model);
}
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
std::string& ep_cache_context) {
Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, "");
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
qnn_model);
} else {
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
@ -88,17 +94,23 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = static_cast<size_t>(cache_file.tellg());
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
cache_file.seekg(0, cache_file.beg);
ep_cache_context.reserve(buffer_size);
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
// Load file into buffer
ep_cache_context.assign(std::istreambuf_iterator<char>(cache_file), std::istreambuf_iterator<char>());
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
cache_file.close();
ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file.");
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
}
ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty.");
return Status::OK();
}

View file

@ -18,6 +18,7 @@ namespace onnxruntime {
namespace qnn {
class QnnModel;
class QnnBackendManager;
static const std::string EPCONTEXT_OP = "EPContext";
static const std::string MAIN_CONTEXT = "main_context";
@ -37,32 +38,24 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
std::vector<NodeArg*>& node_args,
onnxruntime::Graph& graph);
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
std::string& ep_engine_cache,
const logging::Logger& logger);
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
std::string& ep_cache_context);
class QnnCacheModelHandler {
public:
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);
Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
std::string& ep_engine_cache,
const logging::Logger& logger) const {
Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
if (is_qnn_ctx_model) {
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache));
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger));
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}
return Status::OK();
}
@ -92,12 +85,25 @@ class QnnCacheModelHandler {
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const logging::Logger& logger);
private:
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model);
private:
bool is_metadata_ready_ = false;
// model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model
std::string model_name_ = "";
std::string model_description_ = "";
std::string graph_partition_name_ = "";
std::string cache_source_ = "";
std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool get_capability_round_2_ = false;

View file

@ -448,20 +448,7 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
return context_buffer;
}
Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
QnnModel& qnn_model,
bool& loaded_from_cache) {
loaded_from_cache = false;
if (!ep_engine_cache.empty()) {
ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model));
loaded_from_cache = true;
}
return Status::OK();
}
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) {
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@ -474,8 +461,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
static_cast<void*>(const_cast<char*>(buffer.c_str())),
static_cast<uint64_t>(buffer.length()),
static_cast<void*>(buffer),
buffer_length,
&binary_info,
&binary_info_size);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
@ -502,8 +489,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
static_cast<void*>(const_cast<char*>(buffer.c_str())),
static_cast<uint64_t>(buffer.length()),
static_cast<void*>(buffer),
buffer_length,
&context_,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");

View file

@ -74,9 +74,7 @@ class QnnBackendManager {
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
QnnModel& qnn_model,
bool& loaded_from_cache);
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model);
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
@ -174,8 +172,6 @@ class QnnBackendManager {
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
}
Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model);
private:
const std::string backend_path_;
const logging::Logger* logger_ = nullptr;

View file

@ -483,33 +483,28 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
bool is_qnn_ctx_model = false;
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
if (context_cache_enabled_ || is_qnn_ctx_model) {
bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists();
if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
bool loaded_from_cache = false;
std::string ep_engine_cache;
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer,
context_cache_path_,
is_qnn_ctx_model,
qnn_cache_model_handler_->GetIsContextCacheFileExists(),
ep_engine_cache,
logger));
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache,
*(qnn_model.get()),
loaded_from_cache));
// Load and execute from cached context if exist
if (loaded_from_cache) {
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer,
context_cache_path_,
is_qnn_ctx_model,
is_ctx_file_exist,
qnn_backend_manager_.get(),
*(qnn_model.get()),
logger));
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
// the name here should be same with context->node_name in compute_info
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
// the name here should be same with context->node_name in compute_info
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
return Status::OK();
}
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
return Status::OK();
}
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
@ -524,8 +519,6 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
qnn_models_,
logger));
}
qnn_cache_model_handler_.reset();
return Status::OK();
}
} // namespace onnxruntime

View file

@ -119,7 +119,7 @@ int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const {
return node_attributes_.at(key).i();
}
std::string NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
if (!HasAttr(key))
return def_val;

View file

@ -44,7 +44,7 @@ class NodeAttrHelper {
int64_t Get(const std::string& key, int64_t def_val) const;
std::string Get(const std::string& key, const std::string& def_val) const;
const std::string& Get(const std::string& key, const std::string& def_val) const;
std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;
std::vector<float> Get(const std::string& key, const std::vector<float>& def_val) const;