mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[QNN EP] Qnn batchnorm Op support (#15222)
### Description Support BatchNorm Op in Qnn EP Node Unit group support for BatchNorm, Exp ops ### Motivation and Context Enable more models.
This commit is contained in:
parent
0ea965c541
commit
9ef11f1c6a
11 changed files with 316 additions and 7 deletions
|
|
@ -349,6 +349,30 @@ bool InstanceNormalizationNodeGroupSelector::Check(const GraphViewer& graph_view
|
|||
(dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32);
|
||||
}
|
||||
|
||||
bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer,
|
||||
const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const {
|
||||
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (dt_input != dt_output) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
|
||||
if (!int8_allowed_ || dt_scale != dt_input) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -147,6 +147,20 @@ class InstanceNormalizationNodeGroupSelector : public NodeGroupSelector {
|
|||
const std::vector<const Node*>& q_nodes) const override;
|
||||
};
|
||||
|
||||
// DQ nodes for X, W and optionally B, not used for mean, var -> node -> Q
|
||||
class BatchNormalizationNodeGroupSelector : public NodeGroupSelector {
|
||||
public:
|
||||
// default to 'true'
|
||||
BatchNormalizationNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {}
|
||||
|
||||
private:
|
||||
bool Check(const GraphViewer& graph_viewer, const Node& node,
|
||||
const std::vector<const Node*>& dq_nodes,
|
||||
const std::vector<const Node*>& q_nodes) const override;
|
||||
|
||||
bool int8_allowed_;
|
||||
};
|
||||
|
||||
/*
|
||||
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
|
||||
*/
|
||||
|
|
@ -246,6 +260,13 @@ class InstanceNormalizationSelector : public BaseSelector {
|
|||
: BaseSelector(std::make_unique<InstanceNormalizationNodeGroupSelector>()) {}
|
||||
};
|
||||
|
||||
// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q
|
||||
class BatchNormalizationSelector : public BaseSelector {
|
||||
public:
|
||||
BatchNormalizationSelector(bool int8_allowed = false)
|
||||
: BaseSelector(std::make_unique<BatchNormalizationNodeGroupSelector>(int8_allowed)) {}
|
||||
};
|
||||
|
||||
} // namespace QDQ
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
|
|||
{"Slice", {}},
|
||||
{"Softmax", {}},
|
||||
{"Sqrt", {}},
|
||||
{"Tanh", {}}};
|
||||
{"Tanh", {}},
|
||||
{"Exp", {}}};
|
||||
}
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
|
||||
return {{"Add", {}},
|
||||
|
|
@ -80,6 +81,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetGemmOpVersionsMap() {
|
|||
// Builds the op-versions map for InstanceNormalization: one entry with an empty version list.
static const OpVersionsAndSelector::OpVersionsMap GetInstanceNormalizationOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap op_versions{{"InstanceNormalization", {}}};
  return op_versions;
}
|
||||
// Builds the op-versions map for BatchNormalization: one entry with an empty version list.
static const OpVersionsAndSelector::OpVersionsMap GetBatchNormalizationOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap op_versions{{"BatchNormalization", {}}};
  return op_versions;
}
|
||||
|
||||
/* Selector rules registration related */
|
||||
void RegisterMiscSelectors(Selectors& qdq_selectors) {
|
||||
|
|
@ -146,6 +150,13 @@ void RegisterInstanceNormalizationSelector(Selectors& qdq_selectors) {
|
|||
std::move(selector));
|
||||
}
|
||||
|
||||
// Registers a default-constructed BatchNormalizationNodeGroupSelector for the BatchNormalization op.
void RegisterBatchNormalizationSelector(Selectors& qdq_selectors) {
  std::unique_ptr<NodeGroupSelector> group_selector{
      std::make_unique<BatchNormalizationNodeGroupSelector>()};

  qdq_selectors.RegisterSelector(GetBatchNormalizationOpVersionsMap(), std::move(group_selector));
}
|
||||
|
||||
void SelectorManager::CreateSelectors() {
|
||||
RegisterMiscSelectors(qdq_selectors_);
|
||||
RegisterUnarySelectors(qdq_selectors_);
|
||||
|
|
@ -156,6 +167,7 @@ void SelectorManager::CreateSelectors() {
|
|||
RegisterMatMulSelector(qdq_selectors_);
|
||||
RegisterGemmSelector(qdq_selectors_);
|
||||
RegisterInstanceNormalizationSelector(qdq_selectors_);
|
||||
RegisterBatchNormalizationSelector(qdq_selectors_);
|
||||
}
|
||||
|
||||
void SelectorManager::InitializeSelectorsMap() {
|
||||
|
|
|
|||
|
|
@ -131,6 +131,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
|
|||
{
|
||||
CreateInstanceNormOpBuilder("InstanceNormalization", *this);
|
||||
}
|
||||
|
||||
{
|
||||
CreateBatchNormOpBuilder("BatchNormalization", *this);
|
||||
}
|
||||
}
|
||||
|
||||
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
|
||||
|
|
|
|||
|
|
@ -82,5 +82,7 @@ void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrati
|
|||
|
||||
void CreateReduceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -147,7 +147,8 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
|
|||
ORT_UNUSED_PARAMETER(do_op_validation);
|
||||
|
||||
const auto& inputs = node_unit.Inputs();
|
||||
for (size_t input_i = 0; input_i < inputs.size(); ++input_i) {
|
||||
const auto input_count = GetInputCountQnnRequired(node_unit);
|
||||
for (size_t input_i = 0; input_i < input_count; ++input_i) {
|
||||
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, is_quantized_model, input_names));
|
||||
}
|
||||
|
||||
|
|
@ -180,7 +181,6 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
|
|||
// Add output
|
||||
// Output part is common for all Ops, only difference is the Op attribute
|
||||
const auto& outputs = node_unit.Outputs();
|
||||
auto output_size = outputs.size();
|
||||
std::vector<std::string> output_names;
|
||||
struct CastNodeInfo {
|
||||
std::string node_name;
|
||||
|
|
@ -189,7 +189,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
|
|||
};
|
||||
std::vector<CastNodeInfo> cast_node_info_vec;
|
||||
|
||||
for (size_t output_i = 0; output_i < output_size && output_i < output_count_; ++output_i) {
|
||||
const auto output_count = GetOutputCountQnnRequired(node_unit);
|
||||
for (size_t output_i = 0; output_i < output_count; ++output_i) {
|
||||
const auto& output_name = outputs[output_i].node_arg.Name();
|
||||
|
||||
Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
|
||||
|
|
|
|||
|
|
@ -162,7 +162,8 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
{"ConvTranspose", "TransposeConv2d"},
|
||||
{"Tile", "Tile"},
|
||||
{"TopK", "TopK"},
|
||||
{"InstanceNormalization", "InstanceNorm"}};
|
||||
{"InstanceNormalization", "InstanceNorm"},
|
||||
{"BatchNormalization", "Batchnorm"}};
|
||||
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
|
||||
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
|
||||
return it->second;
|
||||
|
|
@ -252,7 +253,32 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
int32_t& default_axis_value) const;
|
||||
Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) const;
|
||||
|
||||
mutable size_t output_count_ = std::numeric_limits<uint32_t>::max();
|
||||
size_t GetInputCountQnnRequired(const NodeUnit& node_unit) const {
|
||||
auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType());
|
||||
|
||||
return 0 == input_output_cout.first ? node_unit.Inputs().size() : input_output_cout.first;
|
||||
}
|
||||
|
||||
size_t GetOutputCountQnnRequired(const NodeUnit& node_unit) const {
|
||||
auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType());
|
||||
|
||||
return 0 == input_output_cout.second ? node_unit.Outputs().size() : input_output_cout.second;
|
||||
}
|
||||
|
||||
private:
|
||||
// Looks up the {input count, output count} recorded for an ONNX op type.
// A zero in either position means no count is recorded for that side; op types that are not in
// the table yield {0, 0}.
// The op type is taken by const reference so a lookup does not copy the string.
static const std::pair<size_t, size_t> GetInputOutputCountQnnRequired(const std::string& onnx_op_type) {
  static const std::unordered_map<std::string, std::pair<size_t, size_t>> input_output_count_qnn_required = {
      {"GlobalAveragePool", {0, 1}},
      {"MaxPool", {0, 1}},
      {"BatchNormalization", {3, 1}}};

  const auto pos = input_output_count_qnn_required.find(onnx_op_type);
  if (pos == input_output_count_qnn_required.end()) {
    return {0, 0};
  }
  return pos->second;
}
|
||||
|
||||
private:
|
||||
std::string op_builder_type_;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"

#include "base_op_builder.h"

#include <cmath>
#include <limits>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace qnn {
|
||||
class BatchNormOpBuilder : public BaseOpBuilder {
|
||||
public:
|
||||
BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {}
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder);
|
||||
|
||||
Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
const logging::Logger& logger,
|
||||
bool is_quantized_model) const override final ORT_MUST_USE_RESULT;
|
||||
};
|
||||
|
||||
// BatchNorm is sensitive with data layout, no special validation so far
|
||||
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
|
||||
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
|
||||
// Decides whether a BatchNormalization node unit can be handled by QNN.
//
// Node units in the internal NHWC domain are passed straight to AddToModelBuilder and its status
// is returned. Every other node unit is screened here and rejected (non-OK Status) when:
//   - epsilon is further than 1e-05 away from the default 1e-05,
//   - the model is not quantized and input 0 is not float,
//   - input 0 is not rank 3 or 4,
//   - scale, bias, mean or var is not a 1D tensor of length input_shape[1],
//   - mean or var is not an initializer,
//   - the node unit has more than one output.
// Fewer or more than 5 inputs trips ORT_ENFORCE.
Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                         const NodeUnit& node_unit,
                                         const logging::Logger& logger,
                                         bool is_quantized_model) const {
  if (node_unit.Domain() == kMSInternalNHWCDomain) {
    // It's useless to fallback the node after layout transformation because CPU EP can't support it anyway
    // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported
    return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, is_quantized_model, true);
  }

  NodeAttrHelper node_helper(node_unit);
  const float default_epsilon = 1e-05f;  // Default is 1e-05 according to ONNX spec.
  const float epsilon = node_helper.Get("epsilon", default_epsilon);
  // Use std::abs explicitly: an unqualified abs() may resolve to the C int overload, which would
  // truncate the float difference to 0 and make this check a no-op.
  ORT_RETURN_IF(std::abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon.");

  const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float");

  const auto& inputs = node_unit.Inputs();
  ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec.");

  // Check input type is float for CPU.
  ONNX_NAMESPACE::DataType input_data_type = inputs[0].node_arg.Type();
  ORT_RETURN_IF(!is_quantized_model && input_data_type != float_elem_type, "QNN BatchNorm CPU only support float32.");

  std::vector<uint32_t> input_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");
  const size_t input_rank = input_shape.size();

  ORT_RETURN_IF(input_rank <= 2 || input_rank > 4,
                "QNN BatchNorm only supports input ranks of size 3 or 4.");

  // Safe to index: the rank check above guarantees at least 3 dimensions.
  const uint32_t num_channels = input_shape[1];

  std::vector<uint32_t> scale_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale).");
  ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels,
                "QNN BatchNorm input 1 (scale) must have 1D shape [channel].");

  std::vector<uint32_t> bias_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias).");
  ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels,
                "QNN BatchNorm input 2 (bias) must have 1D shape [channel].");

  std::vector<uint32_t> mean_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean).");
  ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels,
                "QNN BatchNorm input 3 (mean) must have 1D shape [channel].");
  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean.");

  std::vector<uint32_t> var_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var).");
  ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels,
                "QNN BatchNorm input 4 (var) must have 1D shape [channel].");
  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var.");

  ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output.");

  return Status::OK();
}
|
||||
|
||||
void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.AddOpBuilder(op_type, std::make_unique<BatchNormOpBuilder>());
|
||||
}
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -215,7 +215,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
|
|||
qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param));
|
||||
}
|
||||
|
||||
output_count_ = 1;
|
||||
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
|
||||
std::move(input_names),
|
||||
std::move(param_tensor_names),
|
||||
|
|
|
|||
|
|
@ -204,9 +204,17 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
|
|||
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
|
||||
|
||||
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), logger);
|
||||
|
||||
if (supported_nodes.empty()) {
|
||||
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
|
||||
return result;
|
||||
} else if (supported_nodes.size() == 1) {
|
||||
const auto* node = *supported_nodes.begin();
|
||||
if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
|
||||
LOGS(logger, INFO) << "It doesn't make sense just run a Q/DQ node on HTP.";
|
||||
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
const auto gen_metadef_name = [&]() {
|
||||
|
|
|
|||
118
onnxruntime/test/providers/qnn/batch_norm_htp_test.cc
Normal file
118
onnxruntime/test/providers/qnn/batch_norm_htp_test.cc
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#include <string>
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
#include "test/optimizer/qdq_test_utils.h"
|
||||
#include "test/providers/qnn/qnn_test_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
|
||||
|
||||
// Creates the graph:
|
||||
// _______________________
|
||||
// input_u8 -> DQ -> | |
|
||||
// scale_u8 (initializer) -> DQ -> | |
|
||||
// bias_u8 (initializer) -> DQ -> | BatchNormalization | -> Q -> output_u8
|
||||
// mean_u8 (initializer) -> DQ -> | |
|
||||
// var_u8 (initializer) -> DQ -> |_______________________|
|
||||
//
|
||||
// Currently used to test QNN EP.
|
||||
template <typename InputQType, typename ScaleQType, typename BiasQType>
|
||||
GetQDQTestCaseFn BuildQDQBatchNormTestCase(const std::vector<int64_t>& input_shape) {
|
||||
return [input_shape](ModelTestBuilder& builder) {
|
||||
const int64_t num_channels = input_shape[1];
|
||||
const InputQType quant_zero_point = 0;
|
||||
const float quant_scale = 1.0f;
|
||||
|
||||
auto* input = builder.MakeInput<InputQType>(input_shape, static_cast<InputQType>(-1),
|
||||
static_cast<InputQType>(1));
|
||||
auto* dq_input = builder.MakeIntermediate();
|
||||
builder.AddDequantizeLinearNode<InputQType>(input, 0.0039f, quant_zero_point, dq_input);
|
||||
|
||||
auto* dq_scale_output = builder.MakeIntermediate();
|
||||
auto* scale = builder.MakeInitializer<ScaleQType>({num_channels}, static_cast<ScaleQType>(1), static_cast<ScaleQType>(127));
|
||||
builder.AddDequantizeLinearNode<ScaleQType>(scale, 0.0028f, quant_zero_point, dq_scale_output);
|
||||
|
||||
auto* dq_bias_output = builder.MakeIntermediate();
|
||||
auto* bias = builder.MakeInitializer<BiasQType>({num_channels}, static_cast<BiasQType>(0), static_cast<BiasQType>(0));
|
||||
builder.AddDequantizeLinearNode<BiasQType>(bias, quant_scale, quant_zero_point, dq_bias_output);
|
||||
|
||||
auto* dq_mean_output = builder.MakeIntermediate();
|
||||
auto* mean = builder.MakeInitializer<InputQType>({num_channels}, static_cast<InputQType>(0), static_cast<InputQType>(0));
|
||||
builder.AddDequantizeLinearNode<InputQType>(mean, quant_scale, quant_zero_point, dq_mean_output);
|
||||
|
||||
auto* dq_var_output = builder.MakeIntermediate();
|
||||
auto* var = builder.MakeInitializer<InputQType>({num_channels}, static_cast<InputQType>(255), static_cast<InputQType>(255));
|
||||
builder.AddDequantizeLinearNode<InputQType>(var, 0.003921f, 0, dq_var_output);
|
||||
|
||||
auto* batchnorm_output = builder.MakeIntermediate();
|
||||
builder.AddNode("BatchNormalization", {dq_input, dq_scale_output, dq_bias_output, dq_mean_output, dq_var_output}, {batchnorm_output});
|
||||
|
||||
auto* q_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<InputQType>(batchnorm_output, 0.00377f, quant_zero_point, q_output);
|
||||
|
||||
auto* final_output = builder.MakeOutput();
|
||||
builder.AddDequantizeLinearNode<InputQType>(q_output, 0.00377f,
|
||||
quant_zero_point,
|
||||
final_output);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a BatchNormalization model on the QNN HTP backend. Checks the graph node assignment, and that inference
|
||||
* outputs for QNN and CPU match.
|
||||
*
|
||||
* \param input_shape The input's shape.
|
||||
* \param test_description Description of the test for error reporting.
|
||||
* \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None).
|
||||
* \param num_nodes_in_graph The number of expected nodes in the graph.
|
||||
*/
|
||||
static void RunBatchNormQDQTest(const std::vector<int64_t>& input_shape, const char* test_description,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment, int num_nodes_in_graph) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
// Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs.
|
||||
RunQnnModelTest(BuildQDQBatchNormTestCase<uint8_t, uint8_t, uint8_t>(input_shape),
|
||||
provider_options,
|
||||
11,
|
||||
expected_ep_assignment,
|
||||
num_nodes_in_graph,
|
||||
test_description);
|
||||
}
|
||||
|
||||
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
|
||||
// Use an input of rank 3.
|
||||
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm1D) {
  const std::vector<int64_t> rank3_shape{1, 2, 3};
  RunBatchNormQDQTest(rank3_shape, "TestQDQBatchNorm1D", ExpectedEPNodeAssignment::All, 1);
}
|
||||
|
||||
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
|
||||
// Use an input of rank 4.
|
||||
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm2D) {
  const std::vector<int64_t> rank4_shape{2, 3, 4, 5};
  RunBatchNormQDQTest(rank4_shape, "TestQDQBatchNorm2D", ExpectedEPNodeAssignment::All, 1);
}
|
||||
|
||||
// Check that QNN does not take DQ -> BatchNormalization -> Q when the input rank is unsupported.
|
||||
// Use an input of rank 5. QNN BatchNormalization doesn't support 5D on HTP
|
||||
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm3D) {
  const std::vector<int64_t> rank5_shape{1, 2, 3, 4, 5};
  RunBatchNormQDQTest(rank5_shape, "TestQDQBatchNorm3D", ExpectedEPNodeAssignment::None, 8);
}
|
||||
|
||||
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
Loading…
Reference in a new issue