mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[QNN EP] Qnn cache improvement (#17757)
### Description Improve the QNN context binary cache feature to reduce the memory overhead and initialization time overhead. Instead of dumping a Qnn context binary file with metadata as header, we dump a Onnx format file with metadata inside Onnx node. ### Motivation and Context reduce the memory overhead and initialization time overhead
This commit is contained in:
parent
569876fb16
commit
385fab5bae
15 changed files with 770 additions and 285 deletions
|
|
@ -27,6 +27,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.DequantizeWithOrder">com.microsoft.DequantizeWithOrder</a>
|
||||
* <a href="#com.microsoft.DynamicQuantizeLSTM">com.microsoft.DynamicQuantizeLSTM</a>
|
||||
* <a href="#com.microsoft.DynamicQuantizeMatMul">com.microsoft.DynamicQuantizeMatMul</a>
|
||||
* <a href="#com.microsoft.EPContext">com.microsoft.EPContext</a>
|
||||
* <a href="#com.microsoft.EmbedLayerNormalization">com.microsoft.EmbedLayerNormalization</a>
|
||||
* <a href="#com.microsoft.ExpandDims">com.microsoft.ExpandDims</a>
|
||||
* <a href="#com.microsoft.FastGelu">com.microsoft.FastGelu</a>
|
||||
|
|
@ -1520,6 +1521,55 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.EPContext"></a><a name="com.microsoft.epcontext">**com.microsoft.EPContext**</a>
|
||||
|
||||
Onnx node container for EP context.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>embed_mode</tt> : int</dt>
|
||||
<dd>1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content.The path is relative to this Onnx file. Default is 1.</dd>
|
||||
<dt><tt>ep_cache_context</tt> : string</dt>
|
||||
<dd>payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.</dd>
|
||||
<dt><tt>ep_sdk_version</tt> : string</dt>
|
||||
<dd>(Optional) SDK version used to convert the model.</dd>
|
||||
<dt><tt>main_context</tt> : int</dt>
|
||||
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
|
||||
<dt><tt>notes</tt> : string</dt>
|
||||
<dd>(Optional) Some notes for the model</dd>
|
||||
<dt><tt>partition_name</tt> : string</dt>
|
||||
<dd>(Optional) partitioned graph name.</dd>
|
||||
<dt><tt>source</tt> : string</dt>
|
||||
<dd>(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (1 - ∞)
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs</tt> (variadic) : T</dt>
|
||||
<dd>List of tensors for inputs</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - ∞)
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs</tt> (variadic) : T</dt>
|
||||
<dd>One or more outputs, list of tensors for outputs</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double)</dt>
|
||||
<dd>Constrain input and output types.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.EmbedLayerNormalization"></a><a name="com.microsoft.embedlayernormalization">**com.microsoft.EmbedLayerNormalization**</a>
|
||||
|
||||
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
|
||||
|
|
|
|||
|
|
@ -3597,6 +3597,9 @@ struct OrtApi {
|
|||
* "rpc_control_latency": QNN RPC control latency.
|
||||
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
|
||||
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
|
||||
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the Onnx skeleton model.
|
||||
* 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context.
|
||||
* The path is relative path to the Onnx skeleton model file.
|
||||
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
|
||||
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
|
||||
* may alter model/EP partitioning. Use only for debugging.
|
||||
|
|
|
|||
|
|
@ -2844,6 +2844,83 @@ void RegisterContribSchemas() {
|
|||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(EPContext)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc("Onnx node container for EP context.")
|
||||
.Attr(
|
||||
"main_context",
|
||||
"Usually each single EPContext associate with a graph partition."
|
||||
"But for some case like QNN, it has single EPContext contains all partitions."
|
||||
"In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context."
|
||||
"The path is relative to this Onnx file. Default is 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Attr(
|
||||
"ep_cache_context",
|
||||
"payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr(
|
||||
"embed_mode",
|
||||
"1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content."
|
||||
"The path is relative to this Onnx file. Default is 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Attr(
|
||||
"ep_sdk_version",
|
||||
"(Optional) SDK version used to convert the model.",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr(
|
||||
"partition_name",
|
||||
"(Optional) partitioned graph name.",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr(
|
||||
"source",
|
||||
"(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(
|
||||
0,
|
||||
"inputs",
|
||||
"List of tensors for inputs",
|
||||
"T",
|
||||
OpSchema::Variadic,
|
||||
true,
|
||||
1,
|
||||
OpSchema::NonDifferentiable)
|
||||
.Output(
|
||||
0,
|
||||
"outputs",
|
||||
"One or more outputs, list of tensors for outputs",
|
||||
"T",
|
||||
OpSchema::Variadic,
|
||||
true,
|
||||
1,
|
||||
OpSchema::NonDifferentiable)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(int8)",
|
||||
"tensor(int16)",
|
||||
"tensor(int32)",
|
||||
"tensor(int64)",
|
||||
"tensor(uint8)",
|
||||
"tensor(uint16)",
|
||||
"tensor(uint32)",
|
||||
"tensor(uint64)",
|
||||
"tensor(float16)",
|
||||
"tensor(float)",
|
||||
"tensor(double)"},
|
||||
"Constrain input and output types.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
});
|
||||
|
||||
static const char* BitmaskDropout_ver1_doc = R"DOC(
|
||||
BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar).
|
||||
It produces two tensor outputs: output (floating-point tensor) and mask (optional `Tensor<uint32>`). If `training_mode` is true then the output Y will be a random dropout.
|
||||
|
|
|
|||
264
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Normal file
264
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/providers/qnn/builder/qnn_model.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace qnn {
|
||||
|
||||
Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
bool& is_qnn_ctx_model) {
|
||||
is_qnn_ctx_model = false;
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
|
||||
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
|
||||
int count = 0;
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
if (EPCONTEXT_OP == node.OpType()) {
|
||||
is_qnn_ctx_model = true;
|
||||
}
|
||||
++count;
|
||||
}
|
||||
ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
|
||||
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
if (EPCONTEXT_OP == node.OpType()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status CreateNodeArgs(const std::vector<std::string>& names,
|
||||
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
|
||||
std::vector<NodeArg*>& node_args,
|
||||
onnxruntime::Graph& graph) {
|
||||
using namespace ONNX_NAMESPACE;
|
||||
for (size_t i = 0; i < names.size(); ++i) {
|
||||
std::string name = names[i];
|
||||
ORT_RETURN_IF(tensor_info_table.find(name) == tensor_info_table.end(), "Tensor name: ", name, " not found in tensor_info_table");
|
||||
const OnnxTensorInfo& tensor_info = tensor_info_table.at(name);
|
||||
TypeProto tensor_type;
|
||||
tensor_type.mutable_tensor_type()->set_elem_type(tensor_info.data_type_);
|
||||
for (size_t j = 0; j < tensor_info.shape_.size(); ++j) {
|
||||
tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]);
|
||||
}
|
||||
auto& input_arg = graph.GetOrCreateNodeArg(name, &tensor_type);
|
||||
node_args.push_back(&input_arg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context,
|
||||
const logging::Logger& logger) {
|
||||
using namespace onnxruntime;
|
||||
std::shared_ptr<Model> model;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
|
||||
const auto& graph = model->MainGraph();
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context) {
|
||||
const auto& node = graph_viewer.Nodes().begin();
|
||||
NodeAttrHelper node_helper(*node);
|
||||
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
|
||||
if (is_embed_mode) {
|
||||
ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
} else {
|
||||
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
|
||||
|
||||
std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +
|
||||
"/" + external_qnn_context_binary_file_name);
|
||||
size_t buffer_size{0};
|
||||
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
|
||||
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
|
||||
cache_file.seekg(0, cache_file.end);
|
||||
buffer_size = static_cast<size_t>(cache_file.tellg());
|
||||
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
|
||||
cache_file.seekg(0, cache_file.beg);
|
||||
ep_cache_context.reserve(buffer_size);
|
||||
// Load file into buffer
|
||||
ep_cache_context.assign(std::istreambuf_iterator<char>(cache_file), std::istreambuf_iterator<char>());
|
||||
cache_file.close();
|
||||
ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file.");
|
||||
}
|
||||
ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& model_name,
|
||||
std::string& model_description,
|
||||
std::string& graph_partition_name,
|
||||
std::string& cache_source,
|
||||
const logging::Logger& logger) {
|
||||
if (!is_metadata_ready_) {
|
||||
using namespace onnxruntime;
|
||||
std::shared_ptr<Model> model;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
|
||||
const auto& graph = GraphViewer(model->MainGraph());
|
||||
const auto& node = graph.Nodes().begin();
|
||||
NodeAttrHelper node_helper(*node);
|
||||
model_name_ = graph.Name();
|
||||
model_description_ = graph.Description();
|
||||
graph_partition_name_ = node_helper.Get(PARTITION_NAME, "");
|
||||
cache_source_ = node_helper.Get(SOURCE, "");
|
||||
is_metadata_ready_ = true;
|
||||
}
|
||||
model_name = model_name_;
|
||||
model_description = model_description_;
|
||||
graph_partition_name = graph_partition_name_;
|
||||
cache_source = cache_source_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path,
|
||||
const std::string& model_description,
|
||||
const onnxruntime::PathString& model_pathstring) {
|
||||
// Avoid duplicate work
|
||||
if (ctx_file_exists_) {
|
||||
return ctx_file_exists_;
|
||||
}
|
||||
model_description_ = model_description;
|
||||
// Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
|
||||
if (customer_context_cache_path.empty()) {
|
||||
context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx";
|
||||
} else {
|
||||
context_cache_path_ = customer_context_cache_path;
|
||||
}
|
||||
|
||||
ctx_file_exists_ = std::filesystem::exists(context_cache_path_);
|
||||
|
||||
return ctx_file_exists_;
|
||||
}
|
||||
|
||||
Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name,
|
||||
const std::string& graph_partition_name,
|
||||
const logging::Logger& logger) {
|
||||
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
|
||||
|
||||
std::string model_name_from_ctx_cache;
|
||||
std::string model_description_from_ctx_cache;
|
||||
std::string graph_partition_name_from_ctx_cache;
|
||||
std::string cache_source;
|
||||
ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_,
|
||||
model_name_from_ctx_cache,
|
||||
model_description_from_ctx_cache,
|
||||
graph_partition_name_from_ctx_cache,
|
||||
cache_source,
|
||||
logger));
|
||||
|
||||
// The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
|
||||
if (cache_source != kQnnExecutionProvider) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF(model_name != model_name_from_ctx_cache,
|
||||
"Model file name from context cache metadata: " + model_name_from_ctx_cache +
|
||||
" is different with target: " + model_name +
|
||||
". Please make sure the context binary file matches the model.");
|
||||
|
||||
ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache,
|
||||
"Model description from context cache metadata: " + model_description_from_ctx_cache +
|
||||
" is different with target: " + model_description_ +
|
||||
". Please make sure the context binary file matches the model.");
|
||||
|
||||
ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_,
|
||||
"Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache +
|
||||
" is different with target: " + graph_partition_name +
|
||||
". You may need to re-generate the context binary file.");
|
||||
|
||||
get_capability_round_2_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer,
|
||||
uint64_t buffer_size,
|
||||
const std::string& sdk_build_version,
|
||||
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
|
||||
const logging::Logger& logger) {
|
||||
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
|
||||
Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, {}, logger);
|
||||
auto& graph = model.MainGraph();
|
||||
graph.SetDescription(model_description_);
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
int index = 0;
|
||||
// Still need more work to support multiple partition, it's out of EP's scope.
|
||||
// Already have code to make sure it's single partition before this method get invoked.
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
|
||||
Node& fused_node = fused_node_graph.fused_node;
|
||||
// graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id]
|
||||
// dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability
|
||||
auto qnn_model_kv = qnn_models.find(fused_node.Name());
|
||||
ORT_RETURN_IF(qnn_model_kv == qnn_models.end(), fused_node.Name(), " not exist in QnnModel table.");
|
||||
|
||||
auto qnn_model = qnn_model_kv->second.get();
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetInputNames(), qnn_model->GetInputsInfo(), inputs, graph));
|
||||
ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetOutputNames(), qnn_model->GetOutputsInfo(), outputs, graph));
|
||||
|
||||
const std::string& graph_name = graph_viewer.Name();
|
||||
auto& ep_node = graph.AddNode(graph_name,
|
||||
EPCONTEXT_OP,
|
||||
"Onnx Qnn context binary cache for graph partition: " + graph_name,
|
||||
inputs,
|
||||
outputs,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
|
||||
// Only dump the context buffer once since all QNN graph are in one single context
|
||||
if (0 == index) {
|
||||
if (qnn_context_embed_mode_) {
|
||||
std::string cache_payload(buffer, buffer + buffer_size);
|
||||
ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
|
||||
} else {
|
||||
std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin");
|
||||
std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string());
|
||||
std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary);
|
||||
if (!of_stream) {
|
||||
LOGS(logger, ERROR) << "Failed to open create context file.";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
|
||||
}
|
||||
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
|
||||
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
|
||||
}
|
||||
} else {
|
||||
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
|
||||
}
|
||||
int64_t embed_mode = qnn_context_embed_mode_ ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
|
||||
ep_node.AddAttribute(EMBED_MODE, embed_mode);
|
||||
ep_node.AddAttribute(EP_SDK_VER, sdk_build_version);
|
||||
ep_node.AddAttribute(PARTITION_NAME, graph_name);
|
||||
ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
|
||||
++index;
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
108
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
Normal file
108
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "qnn_def.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace qnn {
|
||||
|
||||
class QnnModel;
|
||||
|
||||
static const std::string EPCONTEXT_OP = "EPContext";
|
||||
static const std::string MAIN_CONTEXT = "main_context";
|
||||
static const std::string EMBED_MODE = "embed_mode";
|
||||
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
|
||||
static const std::string EP_SDK_VER = "ep_sdk_version";
|
||||
static const std::string PARTITION_NAME = "partition_name";
|
||||
static const std::string SOURCE = "source";
|
||||
|
||||
Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
bool& is_qnn_ctx_model);
|
||||
|
||||
bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer);
|
||||
|
||||
Status CreateNodeArgs(const std::vector<std::string>& names,
|
||||
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
|
||||
std::vector<NodeArg*>& node_args,
|
||||
onnxruntime::Graph& graph);
|
||||
|
||||
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_engine_cache,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
std::string& ep_cache_context);
|
||||
|
||||
class QnnCacheModelHandler {
|
||||
public:
|
||||
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
|
||||
}
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);
|
||||
|
||||
Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::string& ctx_onnx_model_path,
|
||||
bool is_qnn_ctx_model,
|
||||
bool is_ctx_cache_file_exist,
|
||||
std::string& ep_engine_cache,
|
||||
const logging::Logger& logger) const {
|
||||
if (is_qnn_ctx_model) {
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache));
|
||||
} else if (is_ctx_cache_file_exist) {
|
||||
ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
|
||||
const std::string& model_description,
|
||||
const onnxruntime::PathString& model_pathstring);
|
||||
|
||||
bool GetIsContextCacheFileExists() const {
|
||||
return ctx_file_exists_;
|
||||
}
|
||||
|
||||
Status ValidateWithContextFile(const std::string& model_name,
|
||||
const std::string& graph_name,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
|
||||
std::string& model_name,
|
||||
std::string& model_description,
|
||||
std::string& graph_partition_name,
|
||||
std::string& cache_source,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status GenerateCtxCacheOnnxModel(unsigned char* buffer,
|
||||
uint64_t buffer_size,
|
||||
const std::string& sdk_build_version,
|
||||
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
|
||||
const logging::Logger& logger);
|
||||
|
||||
private:
|
||||
bool is_metadata_ready_ = false;
|
||||
std::string model_name_ = "";
|
||||
std::string model_description_ = "";
|
||||
std::string graph_partition_name_ = "";
|
||||
std::string cache_source_ = "";
|
||||
std::string context_cache_path_ = "";
|
||||
bool ctx_file_exists_ = false;
|
||||
bool get_capability_round_2_ = false;
|
||||
bool qnn_context_embed_mode_ = true;
|
||||
}; // QnnCacheModelHandler
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
#include "qnn_backend_manager.h"
|
||||
#include "qnn_model.h"
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
#include "QnnOpDef.h"
|
||||
#include "HTP/QnnHtpPerfInfrastructure.h"
|
||||
|
|
@ -16,6 +14,7 @@
|
|||
#include "core/common/gsl.h"
|
||||
#include "core/framework/endian_utils.h"
|
||||
#include "core/common/logging/capture.h"
|
||||
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
|
||||
|
||||
// Flag to determine if Backend should do node validation for each opNode added
|
||||
#define DO_GRAPH_NODE_VALIDATIONS 1
|
||||
|
|
@ -28,8 +27,6 @@ typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(const QnnInterface_t**
|
|||
typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(const QnnSystemInterface_t*** providerList,
|
||||
uint32_t* numProviders);
|
||||
|
||||
constexpr const char* QNN_PROVIDER = "ORTQNNEP";
|
||||
|
||||
static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnInterface_t* qnn_interface) {
|
||||
return qnn_interface->apiVersion.coreApiVersion;
|
||||
}
|
||||
|
|
@ -412,135 +409,64 @@ Status QnnBackendManager::ReleaseContext() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path,
|
||||
const std::string& model_description,
|
||||
const onnxruntime::PathString& model_pathstring) {
|
||||
// Avoid duplicate work
|
||||
if (!context_cache_path_.empty()) {
|
||||
return ctx_file_exists_;
|
||||
}
|
||||
model_description_ = model_description;
|
||||
// Use user provided context cache file path if exist, otherwise try model_file.onnx.bin by default
|
||||
if (customer_context_cache_path.empty()) {
|
||||
context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin";
|
||||
} else {
|
||||
context_cache_path_ = customer_context_cache_path;
|
||||
}
|
||||
|
||||
ctx_file_exists_ = std::filesystem::exists(context_cache_path_);
|
||||
|
||||
return ctx_file_exists_;
|
||||
}
|
||||
|
||||
Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
|
||||
const std::vector<uint16_t> data{value};
|
||||
std::vector<unsigned char> data_bytes(sizeof(uint16_t) / sizeof(unsigned char));
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::utils::WriteLittleEndian(gsl::make_span(data), gsl::make_span(data_bytes)));
|
||||
of_stream.write(reinterpret_cast<char*>(data_bytes.data()), data_bytes.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) {
|
||||
std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t& written_buffer_size) {
|
||||
if (nullptr == qnn_interface_.contextGetBinarySize ||
|
||||
nullptr == qnn_interface_.contextGetBinary) {
|
||||
LOGS(*logger_, ERROR) << "Failed to get valid function pointer.";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get valid function pointer.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
uint64_t required_buffer_size(0);
|
||||
Qnn_ErrorHandle_t rt = qnn_interface_.contextGetBinarySize(context_, &required_buffer_size);
|
||||
if (QNN_CONTEXT_NO_ERROR != rt) {
|
||||
LOGS(*logger_, ERROR) << "Failed to get QNN context binary size. Error code: " << rt;
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get QNN context binary size.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<unsigned char[]> context_buffer = std::make_unique<unsigned char[]>(required_buffer_size);
|
||||
if (nullptr == context_buffer) {
|
||||
LOGS(*logger_, ERROR) << "Failed to allocate buffer for context cache.";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate buffer for context cache.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
uint64_t written_buffer_size(0);
|
||||
rt = qnn_interface_.contextGetBinary(context_,
|
||||
reinterpret_cast<void*>(context_buffer.get()),
|
||||
required_buffer_size,
|
||||
&written_buffer_size);
|
||||
if (QNN_CONTEXT_NO_ERROR != rt) {
|
||||
LOGS(*logger_, ERROR) << "Failed to get context binary.";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get context binary.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (required_buffer_size < written_buffer_size) {
|
||||
LOGS(*logger_, ERROR) << "Context written buffer size: " << written_buffer_size
|
||||
<< " exceeds allocated buffer size: " << required_buffer_size;
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context written buffer exceeds allocated buffer size.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary);
|
||||
if (!of_stream) {
|
||||
LOGS(*logger_, ERROR) << "Failed to open cached context file.";
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
|
||||
LOGS(*logger_, VERBOSE) << "Get context binary buffer succeed.";
|
||||
return context_buffer;
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
|
||||
QnnModel& qnn_model,
|
||||
bool& loaded_from_cache) {
|
||||
loaded_from_cache = false;
|
||||
|
||||
if (!ep_engine_cache.empty()) {
|
||||
ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model));
|
||||
loaded_from_cache = true;
|
||||
}
|
||||
|
||||
// Write Ort metadata into context binary file
|
||||
uint16_t model_name_length = static_cast<uint16_t>(model_name.length());
|
||||
uint16_t graph_name_length = static_cast<uint16_t>(graph_name.length());
|
||||
uint16_t model_description_length = static_cast<uint16_t>(model_description_.length());
|
||||
|
||||
// Header: uint16_t(totale_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description
|
||||
uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length;
|
||||
uint16_t totale_length = header_length + static_cast<uint16_t>(strlen(QNN_PROVIDER));
|
||||
of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER));
|
||||
|
||||
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, header_length));
|
||||
|
||||
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_name_length));
|
||||
of_stream.write(model_name.c_str(), model_name_length);
|
||||
|
||||
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length));
|
||||
of_stream.write(graph_name.c_str(), graph_name_length);
|
||||
|
||||
ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length));
|
||||
of_stream.write(model_description_.c_str(), model_description_length);
|
||||
model_description_.clear();
|
||||
|
||||
LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << totale_length;
|
||||
|
||||
of_stream.write(reinterpret_cast<char*>(context_buffer.get()), written_buffer_size);
|
||||
|
||||
LOGS(*logger_, VERBOSE) << "Dump QNN Context completed.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
|
||||
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) {
|
||||
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
|
||||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
|
||||
nullptr == qnn_sys_interface_.systemContextFree;
|
||||
ORT_RETURN_IF(result, "Failed to get valid function pointer.");
|
||||
|
||||
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
|
||||
|
||||
uint64_t buffer_size{0};
|
||||
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
|
||||
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
|
||||
cache_file.seekg(0, cache_file.end);
|
||||
buffer_size = cache_file.tellg();
|
||||
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
|
||||
cache_file.seekg(0, cache_file.beg);
|
||||
// Skip Ort generated metadata
|
||||
if (ort_generated_ctx_cache_) {
|
||||
cache_file.seekg(ort_ctx_metadata_length_);
|
||||
buffer_size -= ort_ctx_metadata_length_;
|
||||
}
|
||||
|
||||
std::unique_ptr<unsigned char[]> buffer = std::make_unique<unsigned char[]>(buffer_size);
|
||||
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
|
||||
|
||||
// Load file into buffer
|
||||
const auto& read_result = cache_file.read(reinterpret_cast<char*>(buffer.get()), buffer_size);
|
||||
cache_file.close();
|
||||
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
|
||||
|
||||
QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
|
||||
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
|
||||
|
|
@ -548,8 +474,8 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
|
|||
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
|
||||
Qnn_ContextBinarySize_t binary_info_size{0};
|
||||
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
|
||||
static_cast<void*>(buffer.get()),
|
||||
buffer_size,
|
||||
static_cast<void*>(const_cast<char*>(buffer.c_str())),
|
||||
static_cast<uint64_t>(buffer.length()),
|
||||
&binary_info,
|
||||
&binary_info_size);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
|
||||
|
|
@ -576,12 +502,14 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
|
|||
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
|
||||
device_handle_,
|
||||
(const QnnContext_Config_t**)&context_config_,
|
||||
static_cast<void*>(buffer.get()),
|
||||
buffer_size,
|
||||
static_cast<void*>(const_cast<char*>(buffer.c_str())),
|
||||
static_cast<uint64_t>(buffer.length()),
|
||||
&context_,
|
||||
profile_backend_handle_);
|
||||
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
|
||||
|
||||
// More work to support multiple partition, how to map the graph name in compile to qnn graph name
|
||||
// Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
|
||||
ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
|
||||
|
||||
qnn_sys_interface_.systemContextFree(sys_ctx_handle);
|
||||
|
|
@ -590,126 +518,10 @@ Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
|
|||
ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo());
|
||||
context_created_ = true;
|
||||
|
||||
model_description_.clear();
|
||||
model_description_from_ctx_cache_.clear();
|
||||
LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* \brief: Read string data from binary file with given length
|
||||
* \param[in] binary_file - file stream of the binary file
|
||||
* \param[out] result_str - string read from binary file
|
||||
* \param[out] length - length to read
|
||||
*/
|
||||
Status ReadStringFromBinaryFile(std::ifstream& binary_file, std::string& result_str, size_t length) {
|
||||
result_str.resize(length);
|
||||
const auto& read_result = binary_file.read(result_str.data(), length);
|
||||
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* \brief: Read a uint16_t from binary file
|
||||
* \param[in] binary_file - file stream of the binary file
|
||||
* \param[out] value - uint16_t value
|
||||
*/
|
||||
Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) {
|
||||
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(sizeof(uint16_t));
|
||||
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for buffer.");
|
||||
const auto& read_result = binary_file.read(buffer.get(), sizeof(uint16_t));
|
||||
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context binary file.");
|
||||
|
||||
auto src = gsl::make_span<const unsigned char>(reinterpret_cast<unsigned char*>(buffer.get()), sizeof(uint16_t));
|
||||
std::vector<uint16_t> dst(1);
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(src, gsl::make_span(dst)));
|
||||
value = dst[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* \brief: Try to get metadata from Ort generated context cache binary file.
|
||||
* Cached context binary file generated by Ort has some metadata which can be used for validation with the model
|
||||
* to avoid user choose a wrong context binary file which is not for this model
|
||||
* It is treated as Qnn generated context binary file if no metadata found from the file
|
||||
*/
|
||||
Status QnnBackendManager::GetMetadataFromOrtContextFile() {
|
||||
// Only try parse meta data once
|
||||
if (ctx_metadata_tried_) {
|
||||
return Status::OK();
|
||||
}
|
||||
ctx_metadata_tried_ = true;
|
||||
|
||||
uint64_t buffer_size = 0;
|
||||
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
|
||||
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file.");
|
||||
cache_file.seekg(0, cache_file.end);
|
||||
buffer_size = cache_file.tellg();
|
||||
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
|
||||
cache_file.seekg(0, cache_file.beg);
|
||||
|
||||
// Read ort flag
|
||||
std::string ort_flag("");
|
||||
size_t ort_flag_length = strlen(QNN_PROVIDER);
|
||||
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, ort_flag, ort_flag_length));
|
||||
|
||||
// It's not Ort generated context binary file
|
||||
if (strncmp(ort_flag.c_str(), QNN_PROVIDER, ort_flag_length) != 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
ort_generated_ctx_cache_ = true;
|
||||
|
||||
uint16_t str_length = 0;
|
||||
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
|
||||
ort_ctx_metadata_length_ = str_length + static_cast<uint16_t>(ort_flag_length);
|
||||
|
||||
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
|
||||
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(str_length)));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
|
||||
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(str_length)));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
|
||||
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast<size_t>(str_length)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* \brief: Validate the model file name and graph name with Ort generated context cache metadata
|
||||
* \param[in] model_name - model file name
|
||||
* \param[in] graph_name - graph name, e.g Ort_QNN_[hash_id]_[id]. Since GetCapability is called twice,
|
||||
* [hash_id]_[id] changes even for same graph,
|
||||
* so only validate the graph name for 2nd call
|
||||
*/
|
||||
Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) {
|
||||
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");
|
||||
|
||||
// Get metadata from cached context binary file
|
||||
ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile());
|
||||
|
||||
// The context binary file doesn't have ORT metadata, so it is generated from QNN toolchain not from ORT
|
||||
if (!ort_generated_ctx_cache_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF(model_name != model_name_from_ctx_cache_,
|
||||
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ +
|
||||
" is different with target: " + model_name +
|
||||
". Please make sure the context binary file matches the model.");
|
||||
|
||||
ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_,
|
||||
"Model description from context cache metadata: " + model_description_from_ctx_cache_ +
|
||||
" is different with target: " + model_description_ +
|
||||
". Please make sure the context binary file matches the model.");
|
||||
|
||||
ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_,
|
||||
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ +
|
||||
" is different with target: " + graph_name +
|
||||
". You may need to re-generate the context binary file.");
|
||||
|
||||
get_capability_round_2_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
|
||||
if (backend_setup_completed_) {
|
||||
LOGS(logger, VERBOSE) << "Backend setup already!";
|
||||
|
|
@ -728,8 +540,9 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
|
|||
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
|
||||
}
|
||||
|
||||
sdk_build_version_ = GetBackendBuildId();
|
||||
LOGS(logger, VERBOSE) << "Backend build version: "
|
||||
<< GetBackendBuildId();
|
||||
<< sdk_build_version_;
|
||||
|
||||
SetLogger(&logger);
|
||||
LOGS(logger, VERBOSE) << "SetLogger succeed.";
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ namespace onnxruntime {
|
|||
namespace qnn {
|
||||
|
||||
class QnnModel;
|
||||
class QnnCacheModelHandler;
|
||||
|
||||
class QnnBackendManager {
|
||||
public:
|
||||
|
|
@ -71,13 +72,11 @@ class QnnBackendManager {
|
|||
return CreateContext();
|
||||
}
|
||||
|
||||
Status DumpQnnContext(const std::string& model_name, const std::string& graph_name);
|
||||
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
|
||||
|
||||
Status LoadCachedQnnContext(QnnModel& qnn_model);
|
||||
|
||||
Status GetMetadataFromOrtContextFile();
|
||||
|
||||
Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name);
|
||||
Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
|
||||
QnnModel& qnn_model,
|
||||
bool& loaded_from_cache);
|
||||
|
||||
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
|
||||
|
||||
|
|
@ -91,14 +90,6 @@ class QnnBackendManager {
|
|||
|
||||
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
|
||||
|
||||
std::string GetBackendBuildId() {
|
||||
char* backend_build_id{nullptr};
|
||||
if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) {
|
||||
LOGS(*logger_, ERROR) << "Unable to get build Id from the backend.";
|
||||
}
|
||||
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
|
||||
}
|
||||
|
||||
void SetLogger(const logging::Logger* logger) {
|
||||
if (logger_ == nullptr) {
|
||||
logger_ = logger;
|
||||
|
|
@ -133,9 +124,7 @@ class QnnBackendManager {
|
|||
void SetQnnBackendType(uint32_t backend_id);
|
||||
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
|
||||
|
||||
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
|
||||
const std::string& model_description,
|
||||
const onnxruntime::PathString& model_pathstring);
|
||||
const std::string& GetSdkVersion() { return sdk_build_version_; }
|
||||
|
||||
private:
|
||||
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
|
||||
|
|
@ -177,6 +166,16 @@ class QnnBackendManager {
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::string GetBackendBuildId() {
|
||||
char* backend_build_id{nullptr};
|
||||
if (QNN_SUCCESS != qnn_interface_.backendGetBuildId((const char**)&backend_build_id)) {
|
||||
LOGS(*logger_, ERROR) << "Unable to get build Id from the backend.";
|
||||
}
|
||||
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
|
||||
}
|
||||
|
||||
Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model);
|
||||
|
||||
private:
|
||||
const std::string backend_path_;
|
||||
const logging::Logger* logger_ = nullptr;
|
||||
|
|
@ -201,16 +200,7 @@ class QnnBackendManager {
|
|||
std::vector<std::string> op_package_paths_;
|
||||
uint32_t rpc_control_latency_ = 0;
|
||||
HtpPerformanceMode htp_performance_mode_;
|
||||
std::string model_name_from_ctx_cache_ = "";
|
||||
std::string graph_name_from_ctx_cache_ = "";
|
||||
std::string model_description_from_ctx_cache_ = "";
|
||||
std::string model_description_ = "";
|
||||
std::string context_cache_path_ = "";
|
||||
bool ctx_file_exists_ = false;
|
||||
bool ctx_metadata_tried_ = false;
|
||||
bool ort_generated_ctx_cache_ = false;
|
||||
bool get_capability_round_2_ = false;
|
||||
uint16_t ort_ctx_metadata_length_ = 0;
|
||||
std::string sdk_build_version_ = "";
|
||||
#ifdef _WIN32
|
||||
std::set<HMODULE> mod_handles_;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -36,15 +36,16 @@ Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
|
|||
initializer_inputs_.emplace(graph_ini.first);
|
||||
}
|
||||
auto input_defs = fused_node.InputDefs();
|
||||
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, inputs_info_, model_input_index_map_, true));
|
||||
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true));
|
||||
|
||||
auto output_defs = fused_node.OutputDefs();
|
||||
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, outputs_info_, model_output_index_map_));
|
||||
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
|
||||
std::vector<std::string>& input_output_names,
|
||||
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
|
||||
std::unordered_map<std::string, size_t>& input_output_index_map,
|
||||
bool is_input) {
|
||||
|
|
@ -72,6 +73,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeA
|
|||
int32_t data_type = type_proto->tensor_type().elem_type();
|
||||
// use index i so that for graph input, it has initializers included
|
||||
input_output_info_table.emplace(std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(i, data_type, std::move(shape)));
|
||||
input_output_names.push_back(name);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class QnnModel {
|
|||
Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
|
||||
const onnxruntime::Node& fused_node);
|
||||
Status ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
|
||||
std::vector<std::string>& input_output_names,
|
||||
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
|
||||
std::unordered_map<std::string, size_t>& input_output_index,
|
||||
bool is_input = false);
|
||||
|
|
@ -74,6 +75,24 @@ class QnnModel {
|
|||
|
||||
Status DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info);
|
||||
|
||||
const std::vector<std::string>& GetInputNames() const {
|
||||
return input_names_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& GetOutputNames() const {
|
||||
return output_names_;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, OnnxTensorInfo>& GetInputsInfo() const {
|
||||
return inputs_info_;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, OnnxTensorInfo>& GetOutputsInfo() const {
|
||||
return outputs_info_;
|
||||
}
|
||||
|
||||
const std::string& Name() { return graph_info_->Name(); }
|
||||
|
||||
private:
|
||||
const NodeUnit& GetNodeUnit(const Node* node,
|
||||
const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map) const;
|
||||
|
|
@ -87,13 +106,13 @@ class QnnModel {
|
|||
|
||||
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
|
||||
|
||||
private:
|
||||
size_t GetInputOutputIndex(const std::string& name, const std::unordered_map<std::string, OnnxTensorInfo>& io_info) const {
|
||||
auto it = io_info.find(name);
|
||||
ORT_ENFORCE(it != io_info.end(), "Input/Output name not found.");
|
||||
return it->second.index_;
|
||||
}
|
||||
|
||||
private:
|
||||
const logging::Logger& logger_;
|
||||
std::unique_ptr<GraphInfo> graph_info_;
|
||||
QnnBackendManager* qnn_backend_manager_ = nullptr;
|
||||
|
|
@ -102,6 +121,8 @@ class QnnModel {
|
|||
std::unordered_map<std::string, size_t> model_output_index_map_;
|
||||
// TODO: remove initializer_inputs_, use QnnModelWrapper
|
||||
std::unordered_set<std::string> initializer_inputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
std::unordered_map<std::string, OnnxTensorInfo> inputs_info_;
|
||||
std::unordered_map<std::string, OnnxTensorInfo> outputs_info_;
|
||||
std::vector<Qnn_Tensor_t> qnn_inputs_;
|
||||
|
|
|
|||
|
|
@ -101,6 +101,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
|
|||
LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_;
|
||||
}
|
||||
|
||||
bool qnn_context_embed_mode = true;
|
||||
static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode";
|
||||
auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE);
|
||||
if (context_cache_embed_mode_pos != runtime_options_.end()) {
|
||||
qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1";
|
||||
LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode;
|
||||
}
|
||||
|
||||
static const std::string BACKEND_PATH = "backend_path";
|
||||
auto backend_path_pos = runtime_options_.find(BACKEND_PATH);
|
||||
|
||||
|
|
@ -147,6 +155,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
|
|||
rpc_control_latency_,
|
||||
htp_performance_mode_,
|
||||
std::move(qnn_saver_path));
|
||||
qnn_cache_model_handler_ = std::make_unique<qnn::QnnCacheModelHandler>(qnn_context_embed_mode);
|
||||
}
|
||||
|
||||
bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
|
||||
|
|
@ -262,10 +271,16 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
|
||||
const auto& logger = *GetLogger();
|
||||
bool load_from_cached_context = false;
|
||||
if (context_cache_enabled_) {
|
||||
load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
|
||||
graph_viewer.Description(),
|
||||
graph_viewer.ModelPath().ToPathString());
|
||||
bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer);
|
||||
if (is_qnn_ctx_model) {
|
||||
load_from_cached_context = true;
|
||||
}
|
||||
|
||||
// This is for case: QDQ model + Onnx Qnn context cache model
|
||||
if (context_cache_enabled_ && !is_qnn_ctx_model) {
|
||||
load_from_cached_context = qnn_cache_model_handler_->IsContextCacheFileExists(context_cache_path_,
|
||||
graph_viewer.Description(),
|
||||
graph_viewer.ModelPath().ToPathString());
|
||||
}
|
||||
|
||||
// Load from cached context will load the QnnSystem lib and skip the Qnn context creation
|
||||
|
|
@ -275,7 +290,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
return result;
|
||||
}
|
||||
|
||||
if (context_cache_enabled_ && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
|
||||
if ((context_cache_enabled_ || is_qnn_ctx_model) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
|
||||
LOGS(logger, ERROR) << "Qnn context cache only works for HTP or DSP backend.";
|
||||
return result;
|
||||
}
|
||||
|
|
@ -365,9 +380,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
|
||||
const size_t num_of_partitions = result.size();
|
||||
|
||||
if (load_from_cached_context && 1 == num_of_partitions) {
|
||||
rt = qnn_backend_manager_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()),
|
||||
result[0]->sub_graph->GetMetaDef()->name);
|
||||
if (!is_qnn_ctx_model && load_from_cached_context && 1 == num_of_partitions) {
|
||||
rt = qnn_cache_model_handler_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()),
|
||||
result[0]->sub_graph->GetMetaDef()->name,
|
||||
logger);
|
||||
if (Status::OK() != rt) {
|
||||
LOGS(logger, ERROR) << "QNN failed to validate context cache metadata: " << rt.ErrorMessage();
|
||||
return result;
|
||||
|
|
@ -447,22 +463,28 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
|
|||
Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
const auto& logger = *GetLogger();
|
||||
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
|
||||
|
||||
if (context_cache_enabled_) {
|
||||
bool is_qnn_ctx_model = false;
|
||||
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
|
||||
|
||||
if (context_cache_enabled_ || is_qnn_ctx_model) {
|
||||
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
|
||||
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
|
||||
// The dumy_model_description won't be used since IsContextCacheFileExists call cached the result
|
||||
// The graph_viewer.Description here is not same with original model
|
||||
std::string dumy_model_description = "";
|
||||
bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
|
||||
dumy_model_description,
|
||||
graph_viewer.ModelPath().ToPathString());
|
||||
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
|
||||
bool loaded_from_cache = false;
|
||||
std::string ep_engine_cache;
|
||||
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer,
|
||||
context_cache_path_,
|
||||
is_qnn_ctx_model,
|
||||
qnn_cache_model_handler_->GetIsContextCacheFileExists(),
|
||||
ep_engine_cache,
|
||||
logger));
|
||||
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache,
|
||||
*(qnn_model.get()),
|
||||
loaded_from_cache));
|
||||
// Load and execute from cached context if exist
|
||||
if (load_from_cached_context) {
|
||||
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
|
||||
qnn_backend_manager_.get());
|
||||
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get())));
|
||||
if (loaded_from_cache) {
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
|
||||
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
|
||||
|
||||
|
|
@ -473,18 +495,22 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
|
|||
|
||||
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
|
||||
return Status::OK();
|
||||
} else {
|
||||
// Load and execute from Onnx model if not exit and dump the context
|
||||
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
|
||||
// graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id]
|
||||
// dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability
|
||||
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()),
|
||||
graph_viewer.Name()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
|
||||
if (context_cache_enabled_ && !is_qnn_ctx_model) {
|
||||
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
|
||||
uint64_t buffer_size(0);
|
||||
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
|
||||
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(),
|
||||
buffer_size,
|
||||
qnn_backend_manager_->GetSdkVersion(),
|
||||
fused_nodes_and_graphs,
|
||||
qnn_models_,
|
||||
logger));
|
||||
}
|
||||
qnn_cache_model_handler_.reset();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <string>
|
||||
#include "core/providers/qnn/builder/qnn_backend_manager.h"
|
||||
#include "core/providers/qnn/builder/qnn_model.h"
|
||||
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -66,6 +67,7 @@ class QNNExecutionProvider : public IExecutionProvider {
|
|||
bool context_cache_enabled_ = false;
|
||||
std::string context_cache_path_ = "";
|
||||
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
|
||||
std::unique_ptr<qnn::QnnCacheModelHandler> qnn_cache_model_handler_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -56,6 +56,8 @@ void usage() {
|
|||
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
|
||||
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
|
||||
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
|
||||
"\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n"
|
||||
"\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n"
|
||||
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
|
||||
"\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
|
||||
"\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
|
||||
|
|
@ -454,6 +456,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
if (value.empty()) {
|
||||
ORT_THROW("Please provide the QNN backend path.");
|
||||
}
|
||||
} else if (key == "qnn_context_embed_mode") {
|
||||
if (value != "0") {
|
||||
ORT_THROW("Set to 0 to disable qnn_context_embed_mode.");
|
||||
}
|
||||
} else if (key == "qnn_context_cache_enable") {
|
||||
if (value != "1") {
|
||||
ORT_THROW("Set to 1 to enable qnn_context_cache_enable.");
|
||||
|
|
|
|||
|
|
@ -261,7 +261,8 @@ template <typename QuantType>
|
|||
inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn<QuantType>& qdq_model_fn,
|
||||
ProviderOptions qnn_options, int opset_version,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-4f,
|
||||
logging::Severity log_severity = logging::Severity::kERROR) {
|
||||
logging::Severity log_severity = logging::Severity::kERROR,
|
||||
const std::string& qnn_ctx_model_path = "") {
|
||||
// Add kMSDomain to cover contrib op like Gelu
|
||||
const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
|
||||
|
||||
|
|
@ -321,8 +322,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
|
|||
// Run QDQ model on QNN EP and collect outputs.
|
||||
TryEnableQNNSaver(qnn_options);
|
||||
std::vector<OrtValue> qnn_qdq_outputs;
|
||||
InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options),
|
||||
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
|
||||
if (!qnn_ctx_model_path.empty()) {
|
||||
onnx::ModelProto model_proto;
|
||||
onnxruntime::Model qnn_ctx_model;
|
||||
// Load the QNN context cache model from path specified
|
||||
ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(qnn_ctx_model_path), model_proto));
|
||||
std::string qnn_ctx_model_data;
|
||||
model_proto.SerializeToString(&qnn_ctx_model_data);
|
||||
// Run QNN context cache model on QNN EP and collect outputs.
|
||||
InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", QnnExecutionProviderWithOptions(qnn_options),
|
||||
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
|
||||
} else {
|
||||
// Run QDQ model on QNN EP and collect outputs.
|
||||
InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options),
|
||||
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs);
|
||||
}
|
||||
|
||||
if (expected_ep_assignment != ExpectedEPNodeAssignment::None) {
|
||||
// Run QDQ model on CPU EP and collect outputs.
|
||||
|
|
|
|||
|
|
@ -679,10 +679,11 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) {
|
|||
true); // Use com.microsoft domain for Q/DQ ops
|
||||
}
|
||||
|
||||
// Run QDQ model on HTP twice
|
||||
// 1st run will generate the Qnn context cache binary file
|
||||
// 2nd run will load and run from Qnn context cache binary file
|
||||
TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
|
||||
// Run QDQ model on HTP 3 times
|
||||
// 1st run will generate the Qnn context cache onnx file
|
||||
// 2nd run will load and run from QDQ model + Qnn context cache model
|
||||
// 3rd run directly loads and run from Qnn context cache model
|
||||
TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
|
|
@ -690,7 +691,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
|
|||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
provider_options["qnn_context_cache_enable"] = "1";
|
||||
const std::string context_binary_file = "./qnn_context_binary_test.bin";
|
||||
const std::string context_binary_file = "./qnn_context_binary_test.onnx";
|
||||
provider_options["qnn_context_cache_path"] = context_binary_file;
|
||||
|
||||
const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
|
||||
|
|
@ -707,12 +708,120 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
|
|||
// Make sure the Qnn context cache binary file is generated
|
||||
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
|
||||
|
||||
// 2nd run will load and run from Qnn context cache binary file
|
||||
// 2nd run loads and run from QDQ model + Qnn context cache model
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All);
|
||||
|
||||
// 3rd run directly loads and run from Qnn context cache model
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All,
|
||||
1e-4f,
|
||||
logging::Severity::kERROR,
|
||||
context_binary_file);
|
||||
}
|
||||
|
||||
// Run QDQ model on HTP 3 times
|
||||
// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
|
||||
// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
|
||||
// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
|
||||
TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
provider_options["qnn_context_cache_enable"] = "1";
|
||||
const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
|
||||
provider_options["qnn_context_cache_path"] = context_binary_file;
|
||||
provider_options["qnn_context_embed_mode"] = "0";
|
||||
|
||||
const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
|
||||
const std::string op_type = "Atan";
|
||||
|
||||
// Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
|
||||
// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All);
|
||||
|
||||
// Check the Onnx skeleton file is generated
|
||||
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
|
||||
// Check the Qnn context cache binary file is generated
|
||||
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNN_8283143575221199085_1.bin"));
|
||||
|
||||
// 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All);
|
||||
|
||||
// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All,
|
||||
1e-4f,
|
||||
logging::Severity::kERROR,
|
||||
context_binary_file);
|
||||
}
|
||||
|
||||
// Run QDQ model on HTP with 2 inputs
|
||||
// 1st run will generate the Qnn context cache onnx file
|
||||
// 2nd run will load and run from QDQ model + Qnn context cache model
|
||||
// 3rd run directly loads and run from Qnn context cache model
|
||||
TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
provider_options["qnn_context_cache_enable"] = "1";
|
||||
const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx";
|
||||
provider_options["qnn_context_cache_path"] = context_binary_file;
|
||||
|
||||
const TestInputDef<float> input_def1({1, 2, 3}, false, -10.0f, 10.0f);
|
||||
const TestInputDef<float> input_def2({1, 2, 3}, false, -10.0f, 10.0f);
|
||||
const std::string op_type = "Add";
|
||||
|
||||
// Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs.
|
||||
// 1st run will generate the Qnn context cache binary file
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All);
|
||||
|
||||
// Make sure the Qnn context cache binary file is generated
|
||||
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
|
||||
|
||||
// 2nd run loads and run from QDQ model + Qnn context cache model
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All);
|
||||
|
||||
// 3rd run directly loads and run from Qnn context cache model
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
|
||||
provider_options,
|
||||
14,
|
||||
ExpectedEPNodeAssignment::All,
|
||||
1e-4f,
|
||||
logging::Severity::kERROR,
|
||||
context_binary_file);
|
||||
}
|
||||
|
||||
TEST_F(QnnHTPBackendTests, QuantAccuracyTest) {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ jobs:
|
|||
inputs:
|
||||
script: |
|
||||
./build/Release/onnx_test_runner -e qnn \
|
||||
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \
|
||||
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \
|
||||
/data/qdq_models/mobilenetv2-1.0_add_transpose_quant
|
||||
|
||||
- task: CmdLine@2
|
||||
|
|
@ -118,5 +118,5 @@ jobs:
|
|||
inputs:
|
||||
script: |
|
||||
./build/Release/onnx_test_runner -e qnn \
|
||||
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.bin" \
|
||||
-v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so qnn_context_cache_enable|1 qnn_context_cache_path|./build/Release/mobilenet_qdq.onnx_qnn_ctx.onnx" \
|
||||
/data/qdq_models/mobilenetv2-1.0_add_transpose_quant
|
||||
|
|
|
|||
Loading…
Reference in a new issue