mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[QNN EP] Handle rank 3 InstanceNormalization with N != 1 (#17897)
### Description
The QNN HTP backend does not support rank 3 InstanceNorm if the batch size is not 1. To work around this limitation, QNN EP can wrap a rank 4 QNN InstanceNorm op with Reshapes (with the H dim set to 1).

### Motivation and Context
Enable support for more models.
This commit is contained in:
parent
0c5b1598d3
commit
dad70ad4e8
2 changed files with 200 additions and 15 deletions
|
|
@ -24,6 +24,12 @@ class InstanceNormOpBuilder : public BaseOpBuilder {
|
|||
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
|
||||
|
||||
protected:
|
||||
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
const logging::Logger& logger,
|
||||
std::vector<std::string>& input_names,
|
||||
bool do_op_validation) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
std::vector<std::string>&& input_names,
|
||||
|
|
@ -81,6 +87,66 @@ Status InstanceNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Registers the InstanceNormalization inputs (data, scale, bias) with the QNN model.
//
// The scale and bias inputs, and the data input in the common case, are delegated to ProcessInput().
// The exception is a rank 3 data input with a first dimension other than 1 on an NPU backend: that input
// is registered under the rank 4 shape {N, 1, W, C} instead. A non-initializer input gets a Reshape node
// whose output (named "<input>_ort_qnn_ep_reshape") feeds the op; an initializer keeps its name and only
// has its shape rewritten.
//
// qnn_model_wrapper: model being built; receives the tensors and any Reshape node.
// node_unit:         node whose Inputs() are processed (index 0 = data, 1 = scale, 2 = bias).
// logger:            forwarded to ProcessInput().
// input_names:       receives the name of the tensor that feeds the op as input 0 in the reshape case;
//                    also handed to ProcessInput() for the remaining inputs.
// do_op_validation:  forwarded to AddReshapeNode().
//
// Returns Status::OK() on success, otherwise the first error reported by a helper.
Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                            const NodeUnit& node_unit,
                                            const logging::Logger& logger,
                                            std::vector<std::string>& input_names,
                                            bool do_op_validation) const {
  const auto& node_inputs = node_unit.Inputs();

  OnnxInputInfo data_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(node_inputs[0], data_info));

  // The HTP backend only accepts rank 3 inputs when the batch size is 1. For any other batch size the
  // input (and output) must be presented as (N, 1, W, C) so that the op is processed as rank 4.
  const bool wrap_with_reshape = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) &&
                                 data_info.shape.size() == 3 &&
                                 data_info.shape[0] != 1;

  if (!wrap_with_reshape) {
    ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, node_inputs[0], logger, input_names));  // Input 0
  } else {
    const std::string& source_name = node_inputs[0].node_arg.Name();

    // Initializers keep their original name; dynamic inputs are consumed through the Reshape output.
    std::string rank4_name = source_name;
    if (!data_info.is_initializer) {
      rank4_name += "_ort_qnn_ep_reshape";
    }
    input_names.push_back(rank4_name);

    const auto& rank3_dims = data_info.shape;
    std::vector<uint32_t> rank4_dims = {
        rank3_dims[0],  // N
        1,              // Height == 1
        rank3_dims[1],  // Width
        rank3_dims[2]   // Channels
    };

    std::vector<uint8_t> raw_initializer_bytes;
    if (data_info.is_initializer) {
      // The element layout is unaffected by inserting a dimension of size 1, so an initializer needs no
      // Reshape node: its data is unpacked as-is and registered below with the rank 4 dimensions.
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*data_info.initializer_tensor,
                                                                  raw_initializer_bytes));
    } else {
      // Insert a Reshape node that turns the rank 3 input into rank 4 by adding a height dimension of 1.
      const bool feeds_from_graph_input = qnn_model_wrapper.IsGraphInput(source_name);
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(source_name,
                                                           rank4_name,
                                                           data_info.shape,
                                                           rank4_dims,
                                                           data_info.qnn_data_type,
                                                           data_info.quant_param,
                                                           do_op_validation,
                                                           feeds_from_graph_input));
    }

    const Qnn_TensorType_t rank4_tensor_type = GetInputTensorType(qnn_model_wrapper, rank4_name);
    QnnTensorWrapper rank4_tensor(rank4_name, rank4_tensor_type, data_info.qnn_data_type, data_info.quant_param,
                                  std::move(rank4_dims), std::move(raw_initializer_bytes));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rank4_tensor)), "Failed to add tensor.");
  }

  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, node_inputs[1], logger, input_names));  // Scale
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, node_inputs[2], logger, input_names));  // Bias

  return Status::OK();
}
|
||||
|
||||
Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
|
||||
const NodeUnit& node_unit,
|
||||
std::vector<std::string>&& input_names,
|
||||
|
|
@ -100,11 +166,59 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m
|
|||
param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
|
||||
qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
|
||||
std::move(input_names),
|
||||
std::move(param_tensor_names),
|
||||
logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
|
||||
const auto& outputs = node_unit.Outputs();
|
||||
|
||||
OnnxInputInfo output_info = {};
|
||||
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info));
|
||||
|
||||
// HTP backend can only handle rank 3 inputs/outputs if the batch size is 1. If the batch size is not 1,
|
||||
// QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4.
|
||||
if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) ||
|
||||
output_info.shape.size() != 3 || output_info.shape[0] == 1) {
|
||||
return ProcessOutputs(qnn_model_wrapper, node_unit,
|
||||
std::move(input_names),
|
||||
std::move(param_tensor_names),
|
||||
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
|
||||
}
|
||||
|
||||
//
|
||||
// The output is meant to be rank 3 with batch size != 1. Must create a QNN InstanceNorm op with a rank 4 output
|
||||
// that is then reshaped to rank 3 again.
|
||||
//
|
||||
|
||||
const std::string& orig_output_name = outputs[0].node_arg.Name();
|
||||
std::string op_output_name = orig_output_name + "_ort_qnn_ep_reshape";
|
||||
|
||||
std::vector<uint32_t> op_output_shape = {
|
||||
output_info.shape[0], // N
|
||||
1, // H == 1
|
||||
output_info.shape[1], // W
|
||||
output_info.shape[2], // C
|
||||
};
|
||||
|
||||
QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
|
||||
output_info.quant_param, std::vector<uint32_t>(op_output_shape));
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
|
||||
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
|
||||
QNN_OP_PACKAGE_NAME_QTI_AISW,
|
||||
GetQnnOpType(node_unit.OpType()),
|
||||
std::move(input_names),
|
||||
{op_output_name},
|
||||
std::move(param_tensor_names)),
|
||||
"Failed to add node.");
|
||||
|
||||
const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
|
||||
|
||||
// Add Reshape to convert QNN InstanceNorm output back to rank 3 (as expected by the rest of the ONNX graph).
|
||||
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(op_output_name,
|
||||
orig_output_name,
|
||||
op_output_shape,
|
||||
output_info.shape,
|
||||
output_info.qnn_data_type,
|
||||
output_info.quant_param,
|
||||
do_op_validation,
|
||||
false,
|
||||
is_graph_output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,21 +21,26 @@ template <typename QuantType>
|
|||
static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInputDef<float>& input_def,
|
||||
const TestInputDef<float>& scale_def,
|
||||
const TestInputDef<float>& bias_def,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs) {
|
||||
return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder,
|
||||
std::vector<QuantParams<QuantType>>& output_qparams) {
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
|
||||
bool use_contrib_qdq = false) {
|
||||
return [input_def, scale_def, bias_def, attrs,
|
||||
use_contrib_qdq](ModelTestBuilder& builder,
|
||||
std::vector<QuantParams<QuantType>>& output_qparams) {
|
||||
// input => Q => DQ =>
|
||||
NodeArg* input = MakeTestInput(builder, input_def);
|
||||
QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
|
||||
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point);
|
||||
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point,
|
||||
use_contrib_qdq);
|
||||
|
||||
// scale => Q => DQ =>
|
||||
NodeArg* scale = MakeTestInput(builder, scale_def);
|
||||
QuantParams<QuantType> scale_qparams = GetTestInputQuantParams<QuantType>(scale_def);
|
||||
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point);
|
||||
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point,
|
||||
use_contrib_qdq);
|
||||
|
||||
// bias (as int32) => DQ =>
|
||||
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale);
|
||||
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale,
|
||||
use_contrib_qdq);
|
||||
|
||||
// InstanceNormalization operator.
|
||||
auto* instance_norm_output = builder.MakeIntermediate();
|
||||
|
|
@ -46,7 +51,8 @@ static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInput
|
|||
}
|
||||
|
||||
// Add instance_norm_output -> Q -> output_u8
|
||||
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale, output_qparams[0].zero_point);
|
||||
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale,
|
||||
output_qparams[0].zero_point, use_contrib_qdq);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -65,7 +71,8 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,
|
|||
const TestInputDef<float>& scale_def,
|
||||
const TestInputDef<float>& bias_def,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
|
||||
ExpectedEPNodeAssignment expected_ep_assignment) {
|
||||
ExpectedEPNodeAssignment expected_ep_assignment,
|
||||
bool use_contrib_qdq = false) {
|
||||
ProviderOptions provider_options;
|
||||
#if defined(_WIN32)
|
||||
provider_options["backend_path"] = "QnnHtp.dll";
|
||||
|
|
@ -75,11 +82,10 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,
|
|||
|
||||
// Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs.
|
||||
TestQDQModelAccuracy(BuildOpTestCase<float>("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs),
|
||||
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs),
|
||||
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs, use_contrib_qdq),
|
||||
provider_options,
|
||||
18,
|
||||
expected_ep_assignment,
|
||||
1e-5f);
|
||||
expected_ep_assignment);
|
||||
}
|
||||
|
||||
// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
|
||||
|
|
@ -97,6 +103,19 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8) {
|
|||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// Runs the QDQ InstanceNormalization test with uint16_t quantization on a {1, 2, 3, 3} input.
// The scale and bias are initializers, every input has an explicit value-range override, and the
// expected node assignment is ExpectedEPNodeAssignment::All.
TEST_F(QnnHTPBackendTests, InstanceNormU16) {
  const std::vector<float> x_values = {3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f,
                                       -5.40113f, 3.75648f, 6.15357f, -5.25769f, 2.73637f, -0.901382f,
                                       -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f};
  const std::vector<float> gamma_values = {-0.148738f, -1.45158f};
  const std::vector<float> beta_values = {-2.2785083772f, 2.3338717017f};

  const TestInputDef<float> x_def =
      TestInputDef<float>({1, 2, 3, 3}, false, x_values).OverrideValueRange(-10.0f, 10.0f);
  const TestInputDef<float> gamma_def =
      TestInputDef<float>({2}, true, gamma_values).OverrideValueRange(-2.0f, 2.0f);
  const TestInputDef<float> beta_def =
      TestInputDef<float>({2}, true, beta_values).OverrideValueRange(-3.0f, 3.0f);

  RunInstanceNormQDQTest<uint16_t>(x_def,
                                   gamma_def,
                                   beta_def,
                                   {},
                                   ExpectedEPNodeAssignment::All,
                                   true);  // Use contrib Q/DQ ops for 16bit support.
}
|
||||
|
||||
// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
|
||||
// Use an input of rank 3.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
|
||||
|
|
@ -107,6 +126,58 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
|
|||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
|
||||
// which requires wrapping the QNN InstanceNorm op with reshapes.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1) {
|
||||
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
|
||||
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
|
||||
RunInstanceNormQDQTest(TestInputDef<float>({2, 2, 3}, false, input_data),
|
||||
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
|
||||
TestInputDef<float>({2}, true, {1.0f, 3.0f}),
|
||||
{},
|
||||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
|
||||
// which requires wrapping the QNN InstanceNorm op with reshapes.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1) {
|
||||
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
|
||||
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
|
||||
RunInstanceNormQDQTest<uint16_t>(TestInputDef<float>({2, 2, 3}, false, input_data),
|
||||
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
|
||||
TestInputDef<float>({2}, true, {1.0f, 3.0f}),
|
||||
{},
|
||||
ExpectedEPNodeAssignment::All,
|
||||
true); // Use contrib Q/DQ ops for 16bit support.
|
||||
}
|
||||
|
||||
// Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
|
||||
// which requires wrapping the QNN InstanceNorm op with reshapes.
|
||||
// Input 0 is an initializer.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1_Initializer) {
|
||||
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
|
||||
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
|
||||
RunInstanceNormQDQTest(TestInputDef<float>({2, 2, 3}, true, input_data),
|
||||
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
|
||||
TestInputDef<float>({2}, false, {1.0f, 3.0f}),
|
||||
{},
|
||||
ExpectedEPNodeAssignment::All);
|
||||
}
|
||||
|
||||
// Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
|
||||
// which requires wrapping the QNN InstanceNorm op with reshapes.
|
||||
// Input 0 is an initializer.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1_Initializer) {
|
||||
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
|
||||
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
|
||||
RunInstanceNormQDQTest<uint16_t>(TestInputDef<float>({2, 2, 3}, true, input_data),
|
||||
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
|
||||
TestInputDef<float>({2}, false, {1.0f, 3.0f}),
|
||||
{},
|
||||
ExpectedEPNodeAssignment::All,
|
||||
true); // Use contrib Q/DQ ops for 16-bit support.
|
||||
}
|
||||
|
||||
// Check that QNN InstanceNorm operator does not handle inputs with rank > 4.
|
||||
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank5) {
|
||||
RunInstanceNormQDQTest(TestInputDef<float>({1, 2, 3, 3, 3}, false, -10.0f, 10.0f),
|
||||
|
|
|
|||
Loading…
Reference in a new issue