mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[QNN EP] Support LRN operator (#15741)
### Description Adds support for the LRN operator to QNN EP. ### Motivation and Context Enables basic models like googlenet and alexnet to run entirely on QNN EP.
This commit is contained in:
parent
65020d433e
commit
d32c540b2d
7 changed files with 352 additions and 5 deletions
|
|
@ -58,7 +58,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
|
|||
{"Sqrt", {}},
|
||||
{"Atan", {}},
|
||||
{"Tanh", {}},
|
||||
{"Exp", {}}};
|
||||
{"Exp", {}},
|
||||
{"LRN", {}}};
|
||||
}
|
||||
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
|
||||
return {{"Add", {}},
|
||||
|
|
|
|||
|
|
@ -138,6 +138,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
|
|||
{
|
||||
CreateBatchNormOpBuilder("BatchNormalization", *this);
|
||||
}
|
||||
|
||||
{
|
||||
CreateLRNOpBuilder("LRN", *this);
|
||||
}
|
||||
}
|
||||
|
||||
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
|
||||
|
|
|
|||
|
|
@ -84,5 +84,7 @@ void CreateReduceOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
|
|||
|
||||
void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@
|
|||
#include "core/providers/qnn/builder/op_builder.h"
|
||||
#include "core/framework/allocator.h"
|
||||
|
||||
#include "QnnOpDef.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace qnn {
|
||||
|
||||
|
|
@ -91,6 +93,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
}
|
||||
|
||||
static const std::string& GetQnnOpType(const std::string& onnx_op_type) {
|
||||
// TODO: Use QNN operator names defined in "QnnOpDef.h"
|
||||
static const std::unordered_map<std::string, std::string> onnx_op_type_to_qnn_op_type = {
|
||||
{"Add", "ElementWiseAdd"},
|
||||
{"Mul", "ElementWiseMultiply"},
|
||||
|
|
@ -171,7 +174,9 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
{"Tile", "Tile"},
|
||||
{"TopK", "TopK"},
|
||||
{"InstanceNormalization", "InstanceNorm"},
|
||||
{"BatchNormalization", "Batchnorm"}};
|
||||
{"BatchNormalization", "Batchnorm"},
|
||||
|
||||
{"LRN", QNN_OP_LRN}};
|
||||
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
|
||||
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
|
||||
return it->second;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,185 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
|
||||
#include "core/providers/qnn/builder/op_builder_factory.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "onnx/defs/data_type_utils.h"
|
||||
|
||||
#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values).
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace qnn {
|
||||
|
||||
class LRNOpBuilder : public BaseOpBuilder {
|
||||
public:
|
||||
LRNOpBuilder() : BaseOpBuilder("LRNOpBuilder") {}
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LRNOpBuilder);
|
||||
|
||||
Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
const logging::Logger& logger,
|
||||
bool is_quantized_model) const override final ORT_MUST_USE_RESULT;
|
||||
|
||||
protected:
|
||||
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
std::vector<std::string>&& input_names,
|
||||
const logging::Logger& logger,
|
||||
bool is_quantized_model,
|
||||
bool do_op_validation) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
private:
|
||||
static const OnnxAttrInfo<float> onnx_alpha_attr;
|
||||
static const OnnxAttrInfo<float> onnx_beta_attr;
|
||||
static const OnnxAttrInfo<float> onnx_bias_attr;
|
||||
static const OnnxAttrInfo<int64_t> onnx_size_attr;
|
||||
};
|
||||
|
||||
const OnnxAttrInfo<float> LRNOpBuilder::onnx_alpha_attr = {"alpha", 0.0001f};
|
||||
const OnnxAttrInfo<float> LRNOpBuilder::onnx_beta_attr = {"beta", 0.75f};
|
||||
const OnnxAttrInfo<float> LRNOpBuilder::onnx_bias_attr = {"bias", 1.0f};
|
||||
const OnnxAttrInfo<int64_t> LRNOpBuilder::onnx_size_attr = {"size", 0};
|
||||
|
||||
// The LRN operator is layout sensitive. ONNX LRN has layout NCHW, but QNN requires layout NHWC.
|
||||
// The nodes from 1st call of GetCapability do not get layout transformer applied, so their shapes are still NCHW.
|
||||
// The nodes from 2nd call of GetCapability get their layout transformed to NHWC.
|
||||
// Therefore, we need to check the node domain to determine if the layout has been transformed.
|
||||
Status LRNOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
const logging::Logger& logger,
|
||||
bool is_quantized_model) const {
|
||||
if (node_unit.Domain() == kMSInternalNHWCDomain) {
|
||||
return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, is_quantized_model, true);
|
||||
}
|
||||
|
||||
const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float");
|
||||
const auto& inputs = node_unit.Inputs();
|
||||
const auto& outputs = node_unit.Outputs();
|
||||
|
||||
ORT_RETURN_IF(inputs.size() != 1, "QNN EP: LRN operator must have 1 input.");
|
||||
ORT_RETURN_IF(outputs.size() != 1, "QNN EP: LRN operator must have 1 output.");
|
||||
|
||||
const auto& input = inputs[0];
|
||||
const auto& output = outputs[0];
|
||||
|
||||
// Check that the input type is float for CPU.
|
||||
ONNX_NAMESPACE::DataType input_data_type = input.node_arg.Type();
|
||||
ORT_RETURN_IF(!is_quantized_model && input_data_type != float_elem_type,
|
||||
"QNN EP: LRN operator does not support the input type '", input_data_type->c_str(),
|
||||
"' on the CPU backend.");
|
||||
|
||||
// Check that the output type is float for CPU.
|
||||
ONNX_NAMESPACE::DataType output_data_type = output.node_arg.Type();
|
||||
ORT_RETURN_IF(!is_quantized_model && output_data_type != float_elem_type,
|
||||
"QNN EP: LRN operator does not support the input type '", input_data_type->c_str(),
|
||||
"' on the CPU backend.");
|
||||
|
||||
// Check that the input and output have the same shape.
|
||||
std::vector<uint32_t> input_shape;
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, input_shape), "Cannot get shape of input 0");
|
||||
const size_t input_rank = input_shape.size();
|
||||
|
||||
ORT_RETURN_IF(input_rank <= 2 || input_rank > 4, "QNN EP: LRN operator only supports input ranks of size 3 or 4.");
|
||||
|
||||
std::vector<uint32_t> output_shape;
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Cannot get shape of output 0");
|
||||
|
||||
ORT_RETURN_IF(output_shape != input_shape, "QNN EP: LRN operator's output must have the same shape as the input.");
|
||||
|
||||
NodeAttrHelper node_helper(node_unit);
|
||||
|
||||
// 'size' attribute must be odd and > 0.
|
||||
const int64_t onnx_size = GetOnnxAttr(node_helper, onnx_size_attr);
|
||||
ORT_RETURN_IF(onnx_size % 2 == 0, "QNN EP: LRN operator's size attribute must be odd.");
|
||||
|
||||
// 'alpha' attribute must be > 0.0f.
|
||||
const float onnx_alpha = GetOnnxAttr(node_helper, onnx_alpha_attr);
|
||||
ORT_RETURN_IF(onnx_alpha <= 0.0f, "QNN EP: LRN operator's alpha attribute must be greater than zero.");
|
||||
|
||||
// 'alpha' attribute must be > 0.0f.
|
||||
const float onnx_beta = GetOnnxAttr(node_helper, onnx_beta_attr);
|
||||
ORT_RETURN_IF(onnx_beta <= 0.0f, "QNN EP: LRN operator's beta attribute must be greater than zero.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Translates the ONNX LRN attributes into the QNN LRN scalar parameters
// (radius, alpha, beta, bias, region -- registered in that order) and then
// forwards to ProcessOutputs with the collected parameter tensor names.
Status LRNOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                 const NodeUnit& node_unit,
                                                 std::vector<std::string>&& input_names,
                                                 const logging::Logger& logger,
                                                 bool is_quantized_model,
                                                 bool do_op_validation) const {
  std::vector<std::string> qnn_param_names;
  NodeAttrHelper attr_helper(node_unit);

  // Wraps a scalar as a QNN parameter of this node, remembers its tensor name,
  // and transfers the wrapper to the model wrapper.
  auto register_scalar_param = [&](const auto& qnn_param_name, Qnn_Scalar_t scalar) {
    QnnParamWrapper wrapper(node_unit.Index(), node_unit.Name(), qnn_param_name, scalar);
    qnn_param_names.push_back(wrapper.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(wrapper));
  };

  const int64_t lrn_size = GetOnnxAttr(attr_helper, onnx_size_attr);

  // 'radius': QNN describes the window by its radius, ONNX by its full size.
  Qnn_Scalar_t radius_scalar = QNN_SCALAR_INIT;
  radius_scalar.dataType = QNN_DATATYPE_INT_32;
  radius_scalar.int32Value = SafeInt<int32_t>((lrn_size - 1) / 2);
  register_scalar_param(QNN_OP_LRN_PARAM_RADIUS, radius_scalar);

  // 'alpha': QNN doesn't scale alpha by size, so the division is done here.
  const float lrn_alpha = GetOnnxAttr(attr_helper, onnx_alpha_attr);
  Qnn_Scalar_t alpha_scalar = QNN_SCALAR_INIT;
  alpha_scalar.dataType = QNN_DATATYPE_FLOAT_32;
  alpha_scalar.floatValue = lrn_alpha / static_cast<float>(lrn_size);
  register_scalar_param(QNN_OP_LRN_PARAM_ALPHA, alpha_scalar);

  // 'beta': passed through unchanged.
  Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT;
  beta_scalar.dataType = QNN_DATATYPE_FLOAT_32;
  beta_scalar.floatValue = GetOnnxAttr(attr_helper, onnx_beta_attr);
  register_scalar_param(QNN_OP_LRN_PARAM_BETA, beta_scalar);

  // 'bias': passed through unchanged.
  Qnn_Scalar_t bias_scalar = QNN_SCALAR_INIT;
  bias_scalar.dataType = QNN_DATATYPE_FLOAT_32;
  bias_scalar.floatValue = GetOnnxAttr(attr_helper, onnx_bias_attr);
  register_scalar_param(QNN_OP_LRN_PARAM_BIAS, bias_scalar);

  // 'region': ONNX's LRN only supports "across channel".
  Qnn_Scalar_t region_scalar = QNN_SCALAR_INIT;
  region_scalar.dataType = QNN_DATATYPE_UINT_32;
  region_scalar.uint32Value = QNN_OP_LRN_REGION_ACROSS_CHANNEL;
  register_scalar_param(QNN_OP_LRN_PARAM_REGION, region_scalar);

  return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(qnn_param_names),
                        logger, is_quantized_model, do_op_validation);
}
|
||||
|
||||
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.AddOpBuilder(op_type, std::make_unique<LRNOpBuilder>());
|
||||
}
|
||||
|
||||
} // namespace qnn
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -424,11 +424,9 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto&
|
|||
std::vector<uint8_t>& unpacked_tensor) const {
|
||||
if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) {
|
||||
return onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor);
|
||||
} else {
|
||||
return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor);
|
||||
}
|
||||
|
||||
} // namespace qnn
|
||||
|
|
|
|||
152
onnxruntime/test/providers/qnn/lrn_op_test.cc
Normal file
152
onnxruntime/test/providers/qnn/lrn_op_test.cc
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "test/optimizer/qdq_test_utils.h"
|
||||
#include "test/providers/qnn/qnn_test_utils.h"
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Creates a graph with a single LRN operator. Used for testing CPU backend.
|
||||
static GetTestModelFn BuildLRNTestCase(const std::vector<int64_t>& shape, int64_t size,
|
||||
float alpha = 0.0001f, float beta = 0.75f, float bias = 1.0f) {
|
||||
return [shape, size, alpha, beta, bias](ModelTestBuilder& builder) {
|
||||
auto* input = builder.MakeInput<float>(shape, 0.0f, 20.0f);
|
||||
auto* output = builder.MakeOutput();
|
||||
|
||||
Node& lrn_node = builder.AddNode("LRN", {input}, {output});
|
||||
lrn_node.AddAttribute("size", size);
|
||||
lrn_node.AddAttribute("alpha", alpha);
|
||||
lrn_node.AddAttribute("beta", beta);
|
||||
lrn_node.AddAttribute("bias", bias);
|
||||
};
|
||||
}
|
||||
|
||||
// Q/DQ scale used to build Q/DQ test model. This is a global constant
|
||||
// because results from HTP backend are off by exactly this amount.
|
||||
static constexpr float qdq_scale = 0.0038f;
|
||||
|
||||
// Creates a graph with a single Q/DQ LRN operator. Used for testing HTP backend.
|
||||
template <typename InputQType = uint8_t>
|
||||
static GetTestModelFn BuildQDQLRNTestCase(const std::vector<int64_t>& shape, int64_t size,
|
||||
float alpha = 0.0001f, float beta = 0.75f, float bias = 1.0f) {
|
||||
return [shape, size, alpha, beta, bias](ModelTestBuilder& builder) {
|
||||
const InputQType zero_point = std::numeric_limits<InputQType>::max() / 2;
|
||||
|
||||
auto* input = builder.MakeInput<float>(shape, -1.0f, 1.0f);
|
||||
auto* output = builder.MakeOutput();
|
||||
|
||||
// input -> Q -> DQ -> LRN
|
||||
auto* qdq_output = AddQDQNodePair<InputQType>(builder, input, qdq_scale, zero_point);
|
||||
auto* lrn_output = builder.MakeIntermediate();
|
||||
|
||||
Node& lrn_node = builder.AddNode("LRN", {qdq_output}, {lrn_output});
|
||||
lrn_node.AddAttribute("size", size);
|
||||
lrn_node.AddAttribute("alpha", alpha);
|
||||
lrn_node.AddAttribute("beta", beta);
|
||||
lrn_node.AddAttribute("bias", bias);
|
||||
|
||||
// -> Q -> DQ -> output
|
||||
auto* q_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<InputQType>(lrn_output, qdq_scale, zero_point, q_output);
|
||||
builder.AddDequantizeLinearNode<InputQType>(q_output, qdq_scale, zero_point, output);
|
||||
};
|
||||
}
|
||||
|
||||
// Runs an LRN model on the QNN CPU backend. Checks the graph node assignment, and that inference
|
||||
// outputs for QNN EP and CPU EP match.
|
||||
static void RunCPULRNOpTest(const std::vector<int64_t>& shape, int64_t size,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment, const char* test_description,
|
||||
float alpha = 0.0001f, float beta = 0.75f, float bias = 1.0f, int opset = 13) {
|
||||
ProviderOptions provider_options;
|
||||
float fp32_abs_err = 1e-5f; // default tolerance
|
||||
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnCpu.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnCpu.so";
|
||||
fp32_abs_err = 1.5e-5f; // On linux we need slightly larger tolerance.
|
||||
#endif
|
||||
|
||||
constexpr int expected_nodes_in_partition = 1;
|
||||
RunQnnModelTest(BuildLRNTestCase(shape, size, alpha, beta, bias),
|
||||
provider_options,
|
||||
opset,
|
||||
expected_ep_assignment,
|
||||
expected_nodes_in_partition,
|
||||
test_description,
|
||||
fp32_abs_err);
|
||||
}
|
||||
|
||||
// Runs an LRN model on the QNN HTP backend. Checks the graph node assignment, and that inference
|
||||
// outputs for QNN EP and CPU EP match.
|
||||
template <typename QuantType>
|
||||
static void RunQDQLRNOpTest(const std::vector<int64_t>& shape, int64_t size,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment, const char* test_description,
|
||||
float alpha = 0.0001f, float beta = 0.75f, float bias = 1.0f,
|
||||
int opset = 13, float fp32_abs_err = qdq_scale) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
#else
|
||||
provider_options["backend_path"] = "libQnnHtp.so";
|
||||
#endif
|
||||
|
||||
constexpr int expected_nodes_in_partition = 1;
|
||||
RunQnnModelTest(BuildQDQLRNTestCase<QuantType>(shape, size, alpha, beta, bias),
|
||||
provider_options,
|
||||
opset,
|
||||
expected_ep_assignment,
|
||||
expected_nodes_in_partition,
|
||||
test_description,
|
||||
fp32_abs_err + 0.0001f);
|
||||
}
|
||||
|
||||
//
|
||||
// CPU tests:
|
||||
//
|
||||
|
||||
// LRN with a window size of 3 on the CPU backend; the node is expected to be assigned to QNN EP.
TEST_F(QnnCPUBackendTests, TestCPULRNSize3) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 3;
  RunCPULRNOpTest(input_shape, lrn_size, ExpectedEPNodeAssignment::All, "TestCPULRNSize3");
}
|
||||
|
||||
// LRN with a window size of 5 on the CPU backend; the node is expected to be assigned to QNN EP.
TEST_F(QnnCPUBackendTests, TestCPULRNSize5) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 5;
  RunCPULRNOpTest(input_shape, lrn_size, ExpectedEPNodeAssignment::All, "TestCPULRNSize5");
}
|
||||
|
||||
// LRN whose window size (255) exceeds the second dimension (128) of the input, on the CPU backend.
TEST_F(QnnCPUBackendTests, TestCPULRN_size_larger_than_channel) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 255;
  RunCPULRNOpTest(input_shape, lrn_size, ExpectedEPNodeAssignment::All, "TestCPULRN_size_larger_than_channel");
}
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
|
||||
//
|
||||
// HTP tests:
|
||||
//
|
||||
|
||||
// uint8 Q/DQ LRN with a window size of 3 on the HTP backend.
TEST_F(QnnHTPBackendTests, TestHTPLRNSize3) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 3;
  RunQDQLRNOpTest<uint8_t>(input_shape, lrn_size, ExpectedEPNodeAssignment::All, "TestHTPLRNSize3");
}
|
||||
|
||||
// uint8 Q/DQ LRN with a window size of 5 on the HTP backend.
TEST_F(QnnHTPBackendTests, TestHTPLRNSize5) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 5;
  RunQDQLRNOpTest<uint8_t>(input_shape, lrn_size, ExpectedEPNodeAssignment::All, "TestHTPLRNSize5");
}
|
||||
|
||||
// uint8 Q/DQ LRN whose window size (255) exceeds the second dimension (128) of the input, on the HTP backend.
TEST_F(QnnHTPBackendTests, TestHTPLRN_size_larger_than_channel) {
  const std::vector<int64_t> input_shape{1, 128, 4, 5};
  constexpr int64_t lrn_size = 255;
  RunQDQLRNOpTest<uint8_t>(input_shape, lrn_size, ExpectedEPNodeAssignment::All,
                           "TestHTPLRN_size_larger_than_channel");
}
|
||||
|
||||
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
Loading…
Reference in a new issue