[QNN EP] Qnn batchnorm Op support (#15222)

### Description
Support BatchNorm Op in Qnn EP
Node Unit group support for BatchNorm, Exp ops

### Motivation and Context
Enable more models.
This commit is contained in:
Hector Li 2023-04-10 10:36:57 -07:00 committed by GitHub
parent 0ea965c541
commit 9ef11f1c6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 316 additions and 7 deletions

View file

@ -349,6 +349,30 @@ bool InstanceNormalizationNodeGroupSelector::Check(const GraphViewer& graph_view
(dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32);
}
bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}
int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (dt_input != dt_output) {
return false;
}
if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if (!int8_allowed_ || dt_scale != dt_input) {
return false;
}
}
return true;
}
} // namespace QDQ
} // namespace onnxruntime

View file

@ -147,6 +147,20 @@ class InstanceNormalizationNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
};
// DQ nodes for X, W and optionally B, not used for mean, var -> node -> Q
class BatchNormalizationNodeGroupSelector : public NodeGroupSelector {
public:
// default to 'true'
BatchNormalizationNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {}
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
bool int8_allowed_;
};
/*
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
*/
@ -246,6 +260,13 @@ class InstanceNormalizationSelector : public BaseSelector {
: BaseSelector(std::make_unique<InstanceNormalizationNodeGroupSelector>()) {}
};
// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q
class BatchNormalizationSelector : public BaseSelector {
public:
BatchNormalizationSelector(bool int8_allowed = false)
: BaseSelector(std::make_unique<BatchNormalizationNodeGroupSelector>(int8_allowed)) {}
};
} // namespace QDQ
} // namespace onnxruntime

View file

@ -53,7 +53,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
{"Slice", {}},
{"Softmax", {}},
{"Sqrt", {}},
{"Tanh", {}}};
{"Tanh", {}},
{"Exp", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
return {{"Add", {}},
@ -80,6 +81,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetGemmOpVersionsMap() {
static const OpVersionsAndSelector::OpVersionsMap GetInstanceNormalizationOpVersionsMap() {
return {{"InstanceNormalization", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetBatchNormalizationOpVersionsMap() {
return {{"BatchNormalization", {}}};
}
/* Selector rules registration related */
void RegisterMiscSelectors(Selectors& qdq_selectors) {
@ -146,6 +150,13 @@ void RegisterInstanceNormalizationSelector(Selectors& qdq_selectors) {
std::move(selector));
}
void RegisterBatchNormalizationSelector(Selectors& qdq_selectors) {
/* register selector for BatchNormalization op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<BatchNormalizationNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetBatchNormalizationOpVersionsMap(),
std::move(selector));
}
void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterUnarySelectors(qdq_selectors_);
@ -156,6 +167,7 @@ void SelectorManager::CreateSelectors() {
RegisterMatMulSelector(qdq_selectors_);
RegisterGemmSelector(qdq_selectors_);
RegisterInstanceNormalizationSelector(qdq_selectors_);
RegisterBatchNormalizationSelector(qdq_selectors_);
}
void SelectorManager::InitializeSelectorsMap() {

View file

@ -131,6 +131,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreateInstanceNormOpBuilder("InstanceNormalization", *this);
}
{
CreateBatchNormOpBuilder("BatchNormalization", *this);
}
}
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {

View file

@ -82,5 +82,7 @@ void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrati
void CreateReduceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace qnn
} // namespace onnxruntime

View file

@ -147,7 +147,8 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
ORT_UNUSED_PARAMETER(do_op_validation);
const auto& inputs = node_unit.Inputs();
for (size_t input_i = 0; input_i < inputs.size(); ++input_i) {
const auto input_count = GetInputCountQnnRequired(node_unit);
for (size_t input_i = 0; input_i < input_count; ++input_i) {
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, is_quantized_model, input_names));
}
@ -180,7 +181,6 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
// Add output
// Output part is common for all Ops, only difference is the Op attribute
const auto& outputs = node_unit.Outputs();
auto output_size = outputs.size();
std::vector<std::string> output_names;
struct CastNodeInfo {
std::string node_name;
@ -189,7 +189,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
};
std::vector<CastNodeInfo> cast_node_info_vec;
for (size_t output_i = 0; output_i < output_size && output_i < output_count_; ++output_i) {
const auto output_count = GetOutputCountQnnRequired(node_unit);
for (size_t output_i = 0; output_i < output_count; ++output_i) {
const auto& output_name = outputs[output_i].node_arg.Name();
Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;

View file

@ -162,7 +162,8 @@ class BaseOpBuilder : public IOpBuilder {
{"ConvTranspose", "TransposeConv2d"},
{"Tile", "Tile"},
{"TopK", "TopK"},
{"InstanceNormalization", "InstanceNorm"}};
{"InstanceNormalization", "InstanceNorm"},
{"BatchNormalization", "Batchnorm"}};
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
return it->second;
@ -252,7 +253,32 @@ class BaseOpBuilder : public IOpBuilder {
int32_t& default_axis_value) const;
Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) const;
mutable size_t output_count_ = std::numeric_limits<uint32_t>::max();
size_t GetInputCountQnnRequired(const NodeUnit& node_unit) const {
auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType());
return 0 == input_output_cout.first ? node_unit.Inputs().size() : input_output_cout.first;
}
size_t GetOutputCountQnnRequired(const NodeUnit& node_unit) const {
auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType());
return 0 == input_output_cout.second ? node_unit.Outputs().size() : input_output_cout.second;
}
private:
static const std::pair<size_t, size_t> GetInputOutputCountQnnRequired(std::string onnx_op_type) {
static const std::unordered_map<std::string, std::pair<size_t, size_t>> input_output_count_qnn_required = {
{"GlobalAveragePool", {0, 1}},
{"MaxPool", {0, 1}},
{"BatchNormalization", {3, 1}}};
auto pos = input_output_count_qnn_required.find(onnx_op_type);
if (pos == input_output_count_qnn_required.end()) {
return std::make_pair< size_t, size_t>(0, 0);
} else {
return pos->second;
}
}
private:
std::string op_builder_type_;

View file

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "base_op_builder.h"
#include <limits>
namespace onnxruntime {
namespace qnn {
class BatchNormOpBuilder : public BaseOpBuilder {
public:
BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder);
Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model) const override final ORT_MUST_USE_RESULT;
};
// BatchNorm is sensitive with data layout, no special validation so far
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model) const {
if (node_unit.Domain() == kMSInternalNHWCDomain) {
// It's useless to fallback the node after layout transformation because CPU EP can't support it anyway
// Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported
return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, is_quantized_model, true);
} else {
NodeAttrHelper node_helper(node_unit);
const float default_epsilon = 1e-05f;
const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec.
ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon.");
const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float");
const auto& inputs = node_unit.Inputs();
ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec.");
// Check input type is float for CPU.
ONNX_NAMESPACE::DataType input_data_type = inputs[0].node_arg.Type();
ORT_RETURN_IF(!is_quantized_model && input_data_type != float_elem_type, "QNN BatchNorm CPU only support float32.");
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");
const size_t input_rank = input_shape.size();
ORT_RETURN_IF(input_rank <= 2 || input_rank > 4,
"QNN BatchNorm only supports input ranks of size 3 or 4.");
const uint32_t num_channels = input_shape[1];
std::vector<uint32_t> scale_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale).");
ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels,
"QNN BatchNorm input 1 (scale) must have 1D shape [channel].");
std::vector<uint32_t> bias_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias).");
ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels,
"QNN BatchNorm input 2 (bias) must have 1D shape [channel].");
std::vector<uint32_t> mean_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean).");
ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels,
"QNN BatchNorm input 3 (mean) must have 1D shape [channel].");
ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean.");
std::vector<uint32_t> var_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var).");
ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels,
"QNN BatchNorm input 4 (var) must have 1D shape [channel].");
ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var.");
ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output.");
}
return Status::OK();
}
void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<BatchNormOpBuilder>());
}
} // namespace qnn
} // namespace onnxruntime

View file

@ -215,7 +215,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param));
}
output_count_ = 1;
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),

View file

@ -204,9 +204,17 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), logger);
if (supported_nodes.empty()) {
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
return result;
} else if (supported_nodes.size() == 1) {
const auto* node = *supported_nodes.begin();
if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
LOGS(logger, INFO) << "It doesn't make sense just run a Q/DQ node on HTP.";
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
return result;
}
}
const auto gen_metadef_name = [&]() {

View file

@ -0,0 +1,118 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if !defined(ORT_MINIMAL_BUILD)
#include <string>
#include "core/graph/graph.h"
#include "test/optimizer/qdq_test_utils.h"
#include "test/providers/qnn/qnn_test_utils.h"
#include "gtest/gtest.h"
namespace onnxruntime {
namespace test {
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
// Creates the graph:
// _______________________
// input_u8 -> DQ -> | |
// scale_u8 (initializer) -> DQ -> | |
// bias_u8 (initializer) -> DQ -> | BatchNormalization | -> Q -> output_u8
// mean_u8 (initializer) -> DQ -> | |
// var_u8 (initializer) -> DQ -> |_______________________|
//
// Currently used to test QNN EP.
template <typename InputQType, typename ScaleQType, typename BiasQType>
GetQDQTestCaseFn BuildQDQBatchNormTestCase(const std::vector<int64_t>& input_shape) {
return [input_shape](ModelTestBuilder& builder) {
const int64_t num_channels = input_shape[1];
const InputQType quant_zero_point = 0;
const float quant_scale = 1.0f;
auto* input = builder.MakeInput<InputQType>(input_shape, static_cast<InputQType>(-1),
static_cast<InputQType>(1));
auto* dq_input = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<InputQType>(input, 0.0039f, quant_zero_point, dq_input);
auto* dq_scale_output = builder.MakeIntermediate();
auto* scale = builder.MakeInitializer<ScaleQType>({num_channels}, static_cast<ScaleQType>(1), static_cast<ScaleQType>(127));
builder.AddDequantizeLinearNode<ScaleQType>(scale, 0.0028f, quant_zero_point, dq_scale_output);
auto* dq_bias_output = builder.MakeIntermediate();
auto* bias = builder.MakeInitializer<BiasQType>({num_channels}, static_cast<BiasQType>(0), static_cast<BiasQType>(0));
builder.AddDequantizeLinearNode<BiasQType>(bias, quant_scale, quant_zero_point, dq_bias_output);
auto* dq_mean_output = builder.MakeIntermediate();
auto* mean = builder.MakeInitializer<InputQType>({num_channels}, static_cast<InputQType>(0), static_cast<InputQType>(0));
builder.AddDequantizeLinearNode<InputQType>(mean, quant_scale, quant_zero_point, dq_mean_output);
auto* dq_var_output = builder.MakeIntermediate();
auto* var = builder.MakeInitializer<InputQType>({num_channels}, static_cast<InputQType>(255), static_cast<InputQType>(255));
builder.AddDequantizeLinearNode<InputQType>(var, 0.003921f, 0, dq_var_output);
auto* batchnorm_output = builder.MakeIntermediate();
builder.AddNode("BatchNormalization", {dq_input, dq_scale_output, dq_bias_output, dq_mean_output, dq_var_output}, {batchnorm_output});
auto* q_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<InputQType>(batchnorm_output, 0.00377f, quant_zero_point, q_output);
auto* final_output = builder.MakeOutput();
builder.AddDequantizeLinearNode<InputQType>(q_output, 0.00377f,
quant_zero_point,
final_output);
};
}
/**
* Runs an BatchNormalization model on the QNN HTP backend. Checks the graph node assignment, and that inference
* outputs for QNN and CPU match.
*
* \param input_shape The input's shape.
* \param test_description Description of the test for error reporting.
* \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None).
* \param num_modes_in_graph The number of expected nodes in the graph.
*/
static void RunBatchNormQDQTest(const std::vector<int64_t>& input_shape, const char* test_description,
ExpectedEPNodeAssignment expected_ep_assignment, int num_nodes_in_graph) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
// Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs.
RunQnnModelTest(BuildQDQBatchNormTestCase<uint8_t, uint8_t, uint8_t>(input_shape),
provider_options,
11,
expected_ep_assignment,
num_nodes_in_graph,
test_description);
}
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 3.
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm1D) {
RunBatchNormQDQTest({1, 2, 3}, "TestQDQBatchNorm1D", ExpectedEPNodeAssignment::All, 1);
}
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 4.
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm2D) {
RunBatchNormQDQTest({2, 3, 4, 5}, "TestQDQBatchNorm2D", ExpectedEPNodeAssignment::All, 1);
}
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 5. QNN BatchNormalization doesn't support 5D on HTP
TEST_F(QnnHTPBackendTests, TestQDQBatchNorm3D) {
RunBatchNormQDQTest({1, 2, 3, 4, 5}, "TestQDQBatchNorm3D", ExpectedEPNodeAssignment::None, 8);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
} // namespace test
} // namespace onnxruntime
#endif