[QNN EP] Qnn cache improvement (#17757)

### Description
Improve the QNN context binary cache feature to reduce the memory
overhead and initialization time overhead.
Instead of dumping a Qnn context binary file with metadata as header, we
dump a Onnx format file with metadata inside Onnx node.

### Motivation and Context
 reduce the memory overhead and initialization time overhead
This commit is contained in:
Hector Li 2023-10-06 15:56:33 -07:00 committed by GitHub
parent 569876fb16
commit 385fab5bae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 770 additions and 285 deletions

View file

@ -27,6 +27,7 @@ Do not modify directly.*
* <a href="#com.microsoft.DequantizeWithOrder">com.microsoft.DequantizeWithOrder</a>
* <a href="#com.microsoft.DynamicQuantizeLSTM">com.microsoft.DynamicQuantizeLSTM</a>
* <a href="#com.microsoft.DynamicQuantizeMatMul">com.microsoft.DynamicQuantizeMatMul</a>
* <a href="#com.microsoft.EPContext">com.microsoft.EPContext</a>
* <a href="#com.microsoft.EmbedLayerNormalization">com.microsoft.EmbedLayerNormalization</a>
* <a href="#com.microsoft.ExpandDims">com.microsoft.ExpandDims</a>
* <a href="#com.microsoft.FastGelu">com.microsoft.FastGelu</a>
@ -1520,6 +1521,55 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.EPContext"></a><a name="com.microsoft.epcontext">**com.microsoft.EPContext**</a>
Onnx node container for EP context.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>embed_mode</tt> : int</dt>
<dd>1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content.The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>ep_cache_context</tt> : string</dt>
<dd>payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.</dd>
<dt><tt>ep_sdk_version</tt> : string</dt>
<dd>(Optional) SDK version used to convert the model.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>notes</tt> : string</dt>
<dd>(Optional) Some notes for the model</dd>
<dt><tt>partition_name</tt> : string</dt>
<dd>(Optional) partitioned graph name.</dd>
<dt><tt>source</tt> : string</dt>
<dd>(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain</dd>
</dl>
#### Inputs (1 - &#8734;)
<dl>
<dt><tt>inputs</tt> (variadic) : T</dt>
<dd>List of tensors for inputs</dd>
</dl>
#### Outputs (1 - &#8734;)
<dl>
<dt><tt>outputs</tt> (variadic) : T</dt>
<dd>One or more outputs, list of tensors for outputs</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain input and output types.</dd>
</dl>
### <a name="com.microsoft.EmbedLayerNormalization"></a><a name="com.microsoft.embedlayernormalization">**com.microsoft.EmbedLayerNormalization**</a>
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.

View file

@ -3597,6 +3597,9 @@ struct OrtApi {
* "rpc_control_latency": QNN RPC control latency.
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the Onnx skeleton model.
* 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context.
* The path is relative path to the Onnx skeleton model file.
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
* may alter model/EP partitioning. Use only for debugging.

View file

@ -2844,6 +2844,83 @@ void RegisterContribSchemas() {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(EPContext)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc("Onnx node container for EP context.")
.Attr(
"main_context",
"Usually each single EPContext associate with a graph partition."
"But for some case like QNN, it has single EPContext contains all partitions."
"In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context."
"The path is relative to this Onnx file. Default is 1.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"ep_cache_context",
"payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"embed_mode",
"1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content."
"The path is relative to this Onnx file. Default is 1.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"ep_sdk_version",
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"partition_name",
"(Optional) partitioned graph name.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"source",
"(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
.AllowUncheckedAttributes()
.Input(
0,
"inputs",
"List of tensors for inputs",
"T",
OpSchema::Variadic,
true,
1,
OpSchema::NonDifferentiable)
.Output(
0,
"outputs",
"One or more outputs, list of tensors for outputs",
"T",
OpSchema::Variadic,
true,
1,
OpSchema::NonDifferentiable)
.TypeConstraint(
"T",
{"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(float16)",
"tensor(float)",
"tensor(double)"},
"Constrain input and output types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
});
static const char* BitmaskDropout_ver1_doc = R"DOC(
BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar).
It produces two tensor outputs: output (floating-point tensor) and mask (optional `Tensor<uint32>`). If `training_mode` is true then the output Y will be a random dropout.

View file

@ -0,0 +1,264 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
#include "core/graph/constants.h"
#include "core/providers/qnn/builder/qnn_model.h"
#include <iostream>
#include <fstream>
#include <filesystem>
namespace onnxruntime {
namespace qnn {
Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model) {
is_qnn_ctx_model = false;
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
int count = 0;
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
is_qnn_ctx_model = true;
}
++count;
}
ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
}
return Status::OK();
}
bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
return true;
}
}
return false;
}
Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
onnxruntime::Graph& graph) {
using namespace ONNX_NAMESPACE;
for (size_t i = 0; i < names.size(); ++i) {
std::string name = names[i];
ORT_RETURN_IF(tensor_info_table.find(name) == tensor_info_table.end(), "Tensor name: ", name, " not found in tensor_info_table");
const OnnxTensorInfo& tensor_info = tensor_info_table.at(name);
TypeProto tensor_type;
tensor_type.mutable_tensor_type()->set_elem_type(tensor_info.data_type_);
for (size_t j = 0; j < tensor_info.shape_.size(); ++j) {
tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]);
}
auto& input_arg = graph.GetOrCreateNodeArg(name, &tensor_type);
node_args.push_back(&input_arg);
}
return Status::OK();
}
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
std::string& ep_cache_context,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context));
return Status::OK();
}
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
std::string& ep_cache_context) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, "");
} else {
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +
"/" + external_qnn_context_binary_file_name);
size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = static_cast<size_t>(cache_file.tellg());
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
cache_file.seekg(0, cache_file.beg);
ep_cache_context.reserve(buffer_size);
// Load file into buffer
ep_cache_context.assign(std::istreambuf_iterator<char>(cache_file), std::istreambuf_iterator<char>());
cache_file.close();
ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file.");
}
ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty.");
return Status::OK();
}
Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger) {
if (!is_metadata_ready_) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());
const auto& node = graph.Nodes().begin();
NodeAttrHelper node_helper(*node);
model_name_ = graph.Name();
model_description_ = graph.Description();
graph_partition_name_ = node_helper.Get(PARTITION_NAME, "");
cache_source_ = node_helper.Get(SOURCE, "");
is_metadata_ready_ = true;
}
model_name = model_name_;
model_description = model_description_;
graph_partition_name = graph_partition_name_;
cache_source = cache_source_;
return Status::OK();
}
bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring) {
// Avoid duplicate work
if (ctx_file_exists_) {
return ctx_file_exists_;
}
model_description_ = model_description;
// Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
if (customer_context_cache_path.empty()) {
context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx";
} else {
context_cache_path_ = customer_context_cache_path;
}
ctx_file_exists_ = std::filesystem::exists(context_cache_path_);
return ctx_file_exists_;
}
Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name,
const std::string& graph_partition_name,
const logging::Logger& logger) {
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
std::string model_name_from_ctx_cache;
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::string cache_source;
ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
cache_source,
logger));
// The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
if (cache_source != kQnnExecutionProvider) {
return Status::OK();
}
ORT_RETURN_IF(model_name != model_name_from_ctx_cache,
"Model file name from context cache metadata: " + model_name_from_ctx_cache +
" is different with target: " + model_name +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache,
"Model description from context cache metadata: " + model_description_from_ctx_cache +
" is different with target: " + model_description_ +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_,
"Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache +
" is different with target: " + graph_partition_name +
". You may need to re-generate the context binary file.");
get_capability_round_2_ = true;
return Status::OK();
}
Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const logging::Logger& logger) {
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
auto& graph = model.MainGraph();
graph.SetDescription(model_description_);
using namespace ONNX_NAMESPACE;
int index = 0;
// Still need more work to support multiple partition, it's out of EP's scope.
// Already have code to make sure it's single partition before this method get invoked.
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
Node& fused_node = fused_node_graph.fused_node;
// graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id]
// dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability
auto qnn_model_kv = qnn_models.find(fused_node.Name());
ORT_RETURN_IF(qnn_model_kv == qnn_models.end(), fused_node.Name(), " not exist in QnnModel table.");
auto qnn_model = qnn_model_kv->second.get();
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetInputNames(), qnn_model->GetInputsInfo(), inputs, graph));
ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetOutputNames(), qnn_model->GetOutputsInfo(), outputs, graph));
const std::string& graph_name = graph_viewer.Name();
auto& ep_node = graph.AddNode(graph_name,
EPCONTEXT_OP,
"Onnx Qnn context binary cache for graph partition: " + graph_name,
inputs,
outputs,
nullptr,
kMSDomain);
// Only dump the context buffer once since all QNN graph are in one single context
if (0 == index) {
if (qnn_context_embed_mode_) {
std::string cache_payload(buffer, buffer + buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
} else {
std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin");
std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string());
std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary);
if (!of_stream) {
LOGS(logger, ERROR) << "Failed to open create context file.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
}
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
}
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
}
int64_t embed_mode = qnn_context_embed_mode_ ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
ep_node.AddAttribute(EMBED_MODE, embed_mode);
ep_node.AddAttribute(EP_SDK_VER, sdk_build_version);
ep_node.AddAttribute(PARTITION_NAME, graph_name);
ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
++index;
}
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_));
return Status::OK();
}
} // namespace qnn
} // namespace onnxruntime

View file

@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>
#include "qnn_def.h"
#include "core/common/logging/logging.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/shared/utils/utils.h"
#include "core/graph/model.h"
#include "core/framework/execution_provider.h"
namespace onnxruntime {
namespace qnn {
class QnnModel;
static const std::string EPCONTEXT_OP = "EPContext";
static const std::string MAIN_CONTEXT = "main_context";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model);
bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer);
Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
onnxruntime::Graph& graph);
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
std::string& ep_engine_cache,
const logging::Logger& logger);
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
std::string& ep_cache_context);
class QnnCacheModelHandler {
public:
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);
Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
std::string& ep_engine_cache,
const logging::Logger& logger) const {
if (is_qnn_ctx_model) {
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache));
} else if (is_ctx_cache_file_exist) {
ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger));
}
return Status::OK();
}
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring);
bool GetIsContextCacheFileExists() const {
return ctx_file_exists_;
}
Status ValidateWithContextFile(const std::string& model_name,
const std::string& graph_name,
const logging::Logger& logger);
Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger);
Status GenerateCtxCacheOnnxModel(unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const logging::Logger& logger);
private:
bool is_metadata_ready_ = false;
std::string model_name_ = "";
std::string model_description_ = "";
std::string graph_partition_name_ = "";
std::string cache_source_ = "";
std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool get_capability_round_2_ = false;
bool qnn_context_embed_mode_ = true;
}; // QnnCacheModelHandler
} // namespace qnn
} // namespace onnxruntime

View file

@ -3,8 +3,6 @@
#include "qnn_backend_manager.h"
#include "qnn_model.h"
#include <iostream>
#include <fstream>
#include <filesystem>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
@ -16,6 +14,7 @@
#include "core/common/gsl.h"
#include "core/framework/endian_utils.h"
#include "core/common/logging/capture.h"
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
// Flag to determine if Backend should do node validation for each opNode added
#define DO_GRAPH_NODE_VALIDATIONS 1
@ -28,8 +27,6 @@ typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(const QnnInterface_t**
typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(const QnnSystemInterface_t*** providerList,
uint32_t* numProviders);
constexpr const char* QNN_PROVIDER = "ORTQNNEP";
static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnInterface_t* qnn_interface) {
return qnn_interface->apiVersion.coreApiVersion;
}
@ -412,135 +409,64 @@ Status QnnBackendManager::ReleaseContext() {
return Status::OK();
}
bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring) {
// Avoid duplicate work
if (!context_cache_path_.empty()) {
return ctx_file_exists_;
}
model_description_ = model_description;
// Use user provided context cache file path if exist, otherwise try model_file.onnx.bin by default
if (customer_context_cache_path.empty()) {
context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin";
} else {
context_cache_path_ = customer_context_cache_path;
}
ctx_file_exists_ = std::filesystem::exists(context_cache_path_);
return ctx_file_exists_;
}
Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
const std::vector<uint16_t> data{value};
std::vector<unsigned char> data_bytes(sizeof(uint16_t) / sizeof(unsigned char));
ORT_RETURN_IF_ERROR(onnxruntime::utils::WriteLittleEndian(gsl::make_span(data), gsl::make_span(data_bytes)));
of_stream.write(reinterpret_cast<char*>(data_bytes.data()), data_bytes.size());
return Status::OK();
}
Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) {
std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t& written_buffer_size) {
if (nullptr == qnn_interface_.contextGetBinarySize ||
nullptr == qnn_interface_.contextGetBinary) {
LOGS(*logger_, ERROR) << "Failed to get valid function pointer.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get valid function pointer.");
return nullptr;
}
uint64_t required_buffer_size(0);
Qnn_ErrorHandle_t rt = qnn_interface_.contextGetBinarySize(context_, &required_buffer_size);
if (QNN_CONTEXT_NO_ERROR != rt) {
LOGS(*logger_, ERROR) << "Failed to get QNN context binary size. Error code: " << rt;
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get QNN context binary size.");
return nullptr;
}
std::unique_ptr<unsigned char[]> context_buffer = std::make_unique<unsigned char[]>(required_buffer_size);
if (nullptr == context_buffer) {
LOGS(*logger_, ERROR) << "Failed to allocate buffer for context cache.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate buffer for context cache.");
return nullptr;
}
uint64_t written_buffer_size(0);
rt = qnn_interface_.contextGetBinary(context_,
reinterpret_cast<void*>(context_buffer.get()),
required_buffer_size,
&written_buffer_size);
if (QNN_CONTEXT_NO_ERROR != rt) {
LOGS(*logger_, ERROR) << "Failed to get context binary.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get context binary.");
return nullptr;
}
if (required_buffer_size < written_buffer_size) {
LOGS(*logger_, ERROR) << "Context written buffer size: " << written_buffer_size
<< " exceeds allocated buffer size: " << required_buffer_size;
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context written buffer exceeds allocated buffer size.");
return nullptr;
}
std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary);
if (!of_stream) {
LOGS(*logger_, ERROR) << "Failed to open cached context file.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
LOGS(*logger_, VERBOSE) << "Get context binary buffer succeed.";
return context_buffer;
}
Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
QnnModel& qnn_model,
bool& loaded_from_cache) {
loaded_from_cache = false;
if (!ep_engine_cache.empty()) {
ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model));
loaded_from_cache = true;
}
// Write Ort metadata into context binary file
uint16_t model_name_length = static_cast<uint16_t>(model_name.length());
uint16_t graph_name_length = static_cast<uint16_t>(graph_name.length());
uint16_t model_description_length = static_cast<uint16_t>(model_description_.length());
// Header: uint16_t(totale_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description
uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length;
uint16_t totale_length = header_length + static_cast<uint16_t>(strlen(QNN_PROVIDER));
of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER));
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, header_length));
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_name_length));
of_stream.write(model_name.c_str(), model_name_length);
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length));
of_stream.write(graph_name.c_str(), graph_name_length);
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length));
of_stream.write(model_description_.c_str(), model_description_length);
model_description_.clear();
LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << totale_length;
of_stream.write(reinterpret_cast<char*>(context_buffer.get()), written_buffer_size);
LOGS(*logger_, VERBOSE) << "Dump QNN Context completed.";
return Status::OK();
}
Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
ORT_RETURN_IF(result, "Failed to get valid function pointer.");
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
uint64_t buffer_size{0};
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
cache_file.seekg(0, cache_file.beg);
// Skip Ort generated metadata
if (ort_generated_ctx_cache_) {
cache_file.seekg(ort_ctx_metadata_length_);
buffer_size -= ort_ctx_metadata_length_;
}
std::unique_ptr<unsigned char[]> buffer = std::make_unique<unsigned char[]>(buffer_size);
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
// Load file into buffer
const auto& read_result = cache_file.read(reinterpret_cast<char*>(buffer.get()), buffer_size);
cache_file.close();
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
@ -548,8 +474,8 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
static_cast<void*>(buffer.get()),
buffer_size,
static_cast<void*>(const_cast<char*>(buffer.c_str())),
static_cast<uint64_t>(buffer.length()),
&binary_info,
&binary_info_size);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
@ -576,12 +502,14 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
static_cast<void*>(buffer.get()),
buffer_size,
static_cast<void*>(const_cast<char*>(buffer.c_str())),
static_cast<uint64_t>(buffer.length()),
&context_,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
// More work to support multiple partition, how to map the graph name in compile to qnn graph name
// Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
qnn_sys_interface_.systemContextFree(sys_ctx_handle);
@ -590,126 +518,10 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo());
context_created_ = true;
model_description_.clear();
model_description_from_ctx_cache_.clear();
LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed.";
return Status::OK();
}
/* \brief: Read string data from binary file with given length
* \param[in] binary_file - file stream of the binary file
* \param[out] result_str - string read from binary file
* \param[out] length - length to read
*/
Status ReadStringFromBinaryFile(std::ifstream& binary_file, std::string& result_str, size_t length) {
result_str.resize(length);
const auto& read_result = binary_file.read(result_str.data(), length);
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file.");
return Status::OK();
}
/* \brief: Read a uint16_t from binary file
* \param[in] binary_file - file stream of the binary file
* \param[out] value - uint16_t value
*/
Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) {
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(sizeof(uint16_t));
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for buffer.");
const auto& read_result = binary_file.read(buffer.get(), sizeof(uint16_t));
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file.");
auto src = gsl::make_span<const unsigned char>(reinterpret_cast<unsigned char*>(buffer.get()), sizeof(uint16_t));
std::vector<uint16_t> dst(1);
ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(src, gsl::make_span(dst)));
value = dst[0];
return Status::OK();
}
/* \brief: Try to get metadata from Ort generated context cache binary file.
* Cached context binary file generated by Ort has some metadata which can be used for validation with the model
* to avoid user choose a wrong context binary file which is not for this model
* It is treated as Qnn generated context binary file if no metadata found from the file
*/
Status QnnBackendManager::GetMetadataFromOrtContextFile() {
// Only try parse meta data once
if (ctx_metadata_tried_) {
return Status::OK();
}
ctx_metadata_tried_ = true;
uint64_t buffer_size = 0;
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
cache_file.seekg(0, cache_file.beg);
// Read ort flag
std::string ort_flag("");
size_t ort_flag_length = strlen(QNN_PROVIDER);
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, ort_flag, ort_flag_length));
// It's not Ort generated context binary file
if (strncmp(ort_flag.c_str(), QNN_PROVIDER, ort_flag_length) != 0) {
return Status::OK();
}
ort_generated_ctx_cache_ = true;
uint16_t str_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ort_ctx_metadata_length_ = str_length + static_cast<uint16_t>(ort_flag_length);
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(str_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(str_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast<size_t>(str_length)));
return Status::OK();
}
/* \brief: Validate the model file name and graph name with Ort generated context cache metadata
* \param[in] model_name - model file name
* \param[in] graph_name - graph name, e.g Ort_QNN_[hash_id]_[id]. Since GetCapability is called twice,
* [hash_id]_[id] changes even for same graph,
* so only validate the graph name for 2nd call
*/
Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) {
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
// Get metadata from cached context binary file
ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile());
// The context binary file doesn't have ORT metadata, so it is generated from QNN toolchain not from ORT
if (!ort_generated_ctx_cache_) {
return Status::OK();
}
ORT_RETURN_IF(model_name != model_name_from_ctx_cache_,
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ +
" is different with target: " + model_name +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_,
"Model description from context cache metadata: " + model_description_from_ctx_cache_ +
" is different with target: " + model_description_ +
". Please make sure the context binary file matches the model.");
ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_,
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ +
" is different with target: " + graph_name +
". You may need to re-generate the context binary file.");
get_capability_round_2_ = true;
return Status::OK();
}
Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
if (backend_setup_completed_) {
LOGS(logger, VERBOSE) << "Backend setup already!";
@ -728,8 +540,9 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
}
sdk_build_version_ = GetBackendBuildId();
LOGS(logger, VERBOSE) << "Backend build version: "
<< GetBackendBuildId();
<< sdk_build_version_;
SetLogger(&logger);
LOGS(logger, VERBOSE) << "SetLogger succeed.";

View file

@ -22,6 +22,7 @@ namespace onnxruntime {
namespace qnn {
class QnnModel;
class QnnCacheModelHandler;
class QnnBackendManager {
public:
@ -71,13 +72,11 @@ class QnnBackendManager {
return CreateContext();
}
Status DumpQnnContext(const std::string& model_name, const std::string& graph_name);
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
Status LoadCachedQnnContext(QnnModel& qnn_model);
Status GetMetadataFromOrtContextFile();
Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name);
Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
QnnModel& qnn_model,
bool& loaded_from_cache);
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
@ -91,14 +90,6 @@ class QnnBackendManager {
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
std::string GetBackendBuildId() {
char* backend_build_id{nullptr};
if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) {
LOGS(*logger_, ERROR) << "Unable to get build Id from the backend.";
}
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
}
void SetLogger(const logging::Logger* logger) {
if (logger_ == nullptr) {
logger_ = logger;
@ -133,9 +124,7 @@ class QnnBackendManager {
void SetQnnBackendType(uint32_t backend_id);
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring);
const std::string& GetSdkVersion() { return sdk_build_version_; }
private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
@ -177,6 +166,16 @@ class QnnBackendManager {
return ret;
}
std::string GetBackendBuildId() {
char* backend_build_id{nullptr};
if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) {
LOGS(*logger_, ERROR) << "Unable to get build Id from the backend.";
}
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
}
Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model);
private:
const std::string backend_path_;
const logging::Logger* logger_ = nullptr;
@ -201,16 +200,7 @@ class QnnBackendManager {
std::vector<std::string> op_package_paths_;
uint32_t rpc_control_latency_ = 0;
HtpPerformanceMode htp_performance_mode_;
std::string model_name_from_ctx_cache_ = "";
std::string graph_name_from_ctx_cache_ = "";
std::string model_description_from_ctx_cache_ = "";
std::string model_description_ = "";
std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool ctx_metadata_tried_ = false;
bool ort_generated_ctx_cache_ = false;
bool get_capability_round_2_ = false;
uint16_t ort_ctx_metadata_length_ = 0;
std::string sdk_build_version_ = "";
#ifdef _WIN32
std::set<HMODULE> mod_handles_;
#endif

View file

@ -36,15 +36,16 @@ Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
initializer_inputs_.emplace(graph_ini.first);
}
auto input_defs = fused_node.InputDefs();
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, inputs_info_, model_input_index_map_, true));
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true));
auto output_defs = fused_node.OutputDefs();
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, outputs_info_, model_output_index_map_));
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_));
return Status::OK();
}
Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
std::vector<std::string>& input_output_names,
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
std::unordered_map<std::string, size_t>& input_output_index_map,
bool is_input) {
@ -72,6 +73,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeA
int32_t data_type = type_proto->tensor_type().elem_type();
// use index i so that for graph input, it has initializers included
input_output_info_table.emplace(std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(i, data_type, std::move(shape)));
input_output_names.push_back(name);
}
return Status::OK();

View file

@ -47,6 +47,7 @@ class QnnModel {
Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node);
Status ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
std::vector<std::string>& input_output_names,
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
std::unordered_map<std::string, size_t>& input_output_index,
bool is_input = false);
@ -74,6 +75,24 @@ class QnnModel {
Status DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info);
const std::vector<std::string>& GetInputNames() const {
return input_names_;
}
const std::vector<std::string>& GetOutputNames() const {
return output_names_;
}
const std::unordered_map<std::string, OnnxTensorInfo>& GetInputsInfo() const {
return inputs_info_;
}
const std::unordered_map<std::string, OnnxTensorInfo>& GetOutputsInfo() const {
return outputs_info_;
}
const std::string& Name() { return graph_info_->Name(); }
private:
const NodeUnit& GetNodeUnit(const Node* node,
const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map) const;
@ -87,13 +106,13 @@ class QnnModel {
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
private:
size_t GetInputOutputIndex(const std::string& name, const std::unordered_map<std::string, OnnxTensorInfo>& io_info) const {
auto it = io_info.find(name);
ORT_ENFORCE(it != io_info.end(), "Input/Output name not found.");
return it->second.index_;
}
private:
const logging::Logger& logger_;
std::unique_ptr<GraphInfo> graph_info_;
QnnBackendManager* qnn_backend_manager_ = nullptr;
@ -102,6 +121,8 @@ class QnnModel {
std::unordered_map<std::string, size_t> model_output_index_map_;
// TODO: remove initializer_inputs_, use QnnModelWrapper
std::unordered_set<std::string> initializer_inputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
std::unordered_map<std::string, OnnxTensorInfo> inputs_info_;
std::unordered_map<std::string, OnnxTensorInfo> outputs_info_;
std::vector<Qnn_Tensor_t> qnn_inputs_;

View file

@ -101,6 +101,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_;
}
bool qnn_context_embed_mode = true;
static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode";
auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE);
if (context_cache_embed_mode_pos != runtime_options_.end()) {
qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1";
LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode;
}
static const std::string BACKEND_PATH = "backend_path";
auto backend_path_pos = runtime_options_.find(BACKEND_PATH);
@ -147,6 +155,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
rpc_control_latency_,
htp_performance_mode_,
std::move(qnn_saver_path));
qnn_cache_model_handler_ = std::make_unique<qnn::QnnCacheModelHandler>(qnn_context_embed_mode);
}
bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
@ -262,10 +271,16 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const auto& logger = *GetLogger();
bool load_from_cached_context = false;
if (context_cache_enabled_) {
load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
graph_viewer.Description(),
graph_viewer.ModelPath().ToPathString());
bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer);
if (is_qnn_ctx_model) {
load_from_cached_context = true;
}
// This is for case: QDQ model + Onnx Qnn context cache model
if (context_cache_enabled_ && !is_qnn_ctx_model) {
load_from_cached_context = qnn_cache_model_handler_->IsContextCacheFileExists(context_cache_path_,
graph_viewer.Description(),
graph_viewer.ModelPath().ToPathString());
}
// Load from cached context will load the QnnSystem lib and skip the Qnn context creation
@ -275,7 +290,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
return result;
}
if (context_cache_enabled_ && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
if ((context_cache_enabled_ || is_qnn_ctx_model) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
LOGS(logger, ERROR) << "Qnn context cache only works for HTP or DSP backend.";
return result;
}
@ -365,9 +380,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const size_t num_of_partitions = result.size();
if (load_from_cached_context && 1 == num_of_partitions) {
rt = qnn_backend_manager_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()),
result[0]->sub_graph->GetMetaDef()->name);
if (!is_qnn_ctx_model && load_from_cached_context && 1 == num_of_partitions) {
rt = qnn_cache_model_handler_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()),
result[0]->sub_graph->GetMetaDef()->name,
logger);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "QNN failed to validate context cache metadata: " << rt.ErrorMessage();
return result;
@ -447,22 +463,28 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
const auto& logger = *GetLogger();
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
if (context_cache_enabled_) {
bool is_qnn_ctx_model = false;
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
if (context_cache_enabled_ || is_qnn_ctx_model) {
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
// The dumy_model_description won't be used since IsContextCacheFileExists call cached the result
// The graph_viewer.Description here is not same with original model
std::string dumy_model_description = "";
bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
dumy_model_description,
graph_viewer.ModelPath().ToPathString());
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
bool loaded_from_cache = false;
std::string ep_engine_cache;
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer,
context_cache_path_,
is_qnn_ctx_model,
qnn_cache_model_handler_->GetIsContextCacheFileExists(),
ep_engine_cache,
logger));
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache,
*(qnn_model.get()),
loaded_from_cache));
// Load and execute from cached context if exist
if (load_from_cached_context) {
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
qnn_backend_manager_.get());
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get())));
if (loaded_from_cache) {
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
@ -473,18 +495,22 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
return Status::OK();
} else {
// Load and execute from Onnx model if not exit and dump the context
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
// graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id]
// dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()),
graph_viewer.Name()));
}
return Status::OK();
}
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
if (context_cache_enabled_ && !is_qnn_ctx_model) {
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(),
buffer_size,
qnn_backend_manager_->GetSdkVersion(),
fused_nodes_and_graphs,
qnn_models_,
logger));
}
qnn_cache_model_handler_.reset();
return Status::OK();
}

View file

@ -8,6 +8,7 @@
#include <string>
#include "core/providers/qnn/builder/qnn_backend_manager.h"
#include "core/providers/qnn/builder/qnn_model.h"
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
namespace onnxruntime {
@ -66,6 +67,7 @@ class QNNExecutionProvider : public IExecutionProvider {
bool context_cache_enabled_ = false;
std::string context_cache_path_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
std::unique_ptr<qnn::QnnCacheModelHandler> qnn_cache_model_handler_;
};
} // namespace onnxruntime

View file

@ -56,6 +56,8 @@ void usage() {
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n"
"\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n"
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
"\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
"\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
@ -454,6 +456,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
if (value.empty()) {
ORT_THROW("Please provide the QNN backend path.");
}
} else if (key == "qnn_context_embed_mode") {
if (value != "0") {
ORT_THROW("Set to 0 to disable qnn_context_embed_mode.");
}
} else if (key == "qnn_context_cache_enable") {
if (value != "1") {
ORT_THROW("Set to 1 to enable qnn_context_cache_enable.");

View file

@ -261,7 +261,8 @@ template <typename QuantType>
inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn<QuantType>& qdq_model_fn,
ProviderOptions qnn_options, int opset_version,
ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-4f,
logging::Severity log_severity = logging::Severity::kERROR) {
logging::Severity log_severity = logging::Severity::kERROR,
const std::string& qnn_ctx_model_path = "") {
// Add kMSDomain to cover contrib op like Gelu
const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
@ -321,8 +322,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
// Run QDQ model on QNN EP and collect outputs.
TryEnableQNNSaver(qnn_options);
std::vector<OrtValue> qnn_qdq_outputs;
InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options),
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
if (!qnn_ctx_model_path.empty()) {
onnx::ModelProto model_proto;
onnxruntime::Model qnn_ctx_model;
// Load the QNN context cache model from path specified
ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(qnn_ctx_model_path), model_proto));
std::string qnn_ctx_model_data;
model_proto.SerializeToString(&qnn_ctx_model_data);
// Run QNN context cache model on QNN EP and collect outputs.
InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", QnnExecutionProviderWithOptions(qnn_options),
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
} else {
// Run QDQ model on QNN EP and collect outputs.
InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options),
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
}
if (expected_ep_assignment != ExpectedEPNodeAssignment::None) {
// Run QDQ model on CPU EP and collect outputs.

View file

@ -679,10 +679,11 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) {
true); // Use com.microsoft domain for Q/DQ ops
}
// Run QDQ model on HTP twice
// 1st run will generate the Qnn context cache binary file
// 2nd run will load and run from Qnn context cache binary file
TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
// Run QDQ model on HTP 3 times
// 1st run will generate the Qnn context cache onnx file
// 2nd run will load and run from QDQ model + Qnn context cache model
// 3rd run directly loads and run from Qnn context cache model
TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
@ -690,7 +691,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["qnn_context_cache_enable"] = "1";
const std::string context_binary_file = "./qnn_context_binary_test.bin";
const std::string context_binary_file = "./qnn_context_binary_test.onnx";
provider_options["qnn_context_cache_path"] = context_binary_file;
const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
@ -707,12 +708,120 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
// Make sure the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
// 2nd run will load and run from Qnn context cache binary file
// 2nd run loads and run from QDQ model + Qnn context cache model
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All);
// 3rd run directly loads and run from Qnn context cache model
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All,
1e-4f,
logging::Severity::kERROR,
context_binary_file);
}
// Run QDQ model on HTP 3 times
// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["qnn_context_cache_enable"] = "1";
const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
provider_options["qnn_context_cache_path"] = context_binary_file;
provider_options["qnn_context_embed_mode"] = "0";
const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
const std::string op_type = "Atan";
// Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All);
// Check the Onnx skeleton file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
// Check the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNN_8283143575221199085_1.bin"));
// 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All);
// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All,
1e-4f,
logging::Severity::kERROR,
context_binary_file);
}
// Run QDQ model on HTP with 2 inputs
// 1st run will generate the Qnn context cache onnx file
// 2nd run will load and run from QDQ model + Qnn context cache model
// 3rd run directly loads and run from Qnn context cache model
TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["qnn_context_cache_enable"] = "1";
const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx";
provider_options["qnn_context_cache_path"] = context_binary_file;
const TestInputDef<float> input_def1({1, 2, 3}, false, -10.0f, 10.0f);
const TestInputDef<float> input_def2({1, 2, 3}, false, -10.0f, 10.0f);
const std::string op_type = "Add";
// Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs.
// 1st run will generate the Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All);
// Make sure the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
// 2nd run loads and run from QDQ model + Qnn context cache model
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All);
// 3rd run directly loads and run from Qnn context cache model
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
provider_options,
14,
ExpectedEPNodeAssignment::All,
1e-4f,
logging::Severity::kERROR,
context_binary_file);
}
TEST_F(QnnHTPBackendTests, QuantAccuracyTest) {

View file

@ -110,7 +110,7 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \
/data/qdq_models/mobilenetv2-1.0_add_transpose_quant
- task: CmdLine@2
@ -118,5 +118,5 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \
/data/qdq_models/mobilenetv2-1.0_add_transpose_quant