[QNN EP] Handle rank 3 InstanceNormalization with N != 1 (#17897)

### Description
The QNN HTP backend does not support rank 3 InstanceNorm if the batch
size is not 1. To work around this limitation, QNN EP can wrap a rank 4
QNN InstanceNorm op with Reshapes (with the H dim set to 1).

### Motivation and Context
Enable support for more models.
This commit is contained in:
Adrian Lizarraga 2023-10-12 21:52:09 -07:00 committed by GitHub
parent 0c5b1598d3
commit dad70ad4e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 200 additions and 15 deletions

View file

@ -24,6 +24,12 @@ class InstanceNormOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
@ -81,6 +87,66 @@ Status InstanceNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}
// Registers the QNN input tensors for an InstanceNormalization node unit and appends their
// names to `input_names` in the order: input 0, scale, bias.
//
// Scale and bias always go through the generic ProcessInput() path. Input 0 normally does too,
// except on an NPU backend when it is rank 3 with a batch size other than 1: in that case the
// QNN op is fed a rank 4 tensor shaped (N, 1, W, C) instead.
//   - If input 0 is an initializer, its data is unpacked and registered under the original name
//     with the rank 4 shape (no Reshape node is added).
//   - Otherwise, a Reshape node is added that produces "<name>_ort_qnn_ep_reshape" with the
//     rank 4 shape, and that tensor becomes the op's input 0.
//
// Returns the first error encountered, or Status::OK() once all three inputs are registered.
Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                            const NodeUnit& node_unit,
                                            const logging::Logger& logger,
                                            std::vector<std::string>& input_names,
                                            bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();

  OnnxInputInfo input0_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input0_info));

  // HTP backend can only handle rank 3 inputs if the batch size is 1. For any other batch size,
  // QNN EP must present the input (and output) as (N, 1, W, C) and run InstanceNorm as rank 4.
  const bool needs_rank4_wrap = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) &&
                                input0_info.shape.size() == 3 && input0_info.shape[0] != 1;

  if (!needs_rank4_wrap) {
    ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));  // Input 0
  } else {
    const std::string& onnx_input_name = inputs[0].node_arg.Name();
    const bool is_static_input = input0_info.is_initializer;

    // An initializer keeps its original name; a dynamic input is replaced by the Reshape's output.
    const std::string qnn_input_name = is_static_input ? onnx_input_name
                                                       : onnx_input_name + "_ort_qnn_ep_reshape";
    input_names.push_back(qnn_input_name);

    std::vector<uint32_t> rank4_shape = {
        input0_info.shape[0],  // N
        1,                     // Height == 1
        input0_info.shape[1],  // Width
        input0_info.shape[2]   // Channels
    };

    std::vector<uint8_t> static_data;  // Stays empty unless input 0 is an initializer.

    if (is_static_input) {
      // Inserting a unit-sized dimension does not change the element layout, so the initializer
      // bytes are used as-is and only the shape dimensions differ.
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input0_info.initializer_tensor, static_data));
    } else {
      // Add a Reshape node that converts the rank 3 input to rank 4 (i.e., inserts a height of 1).
      const bool is_graph_input = qnn_model_wrapper.IsGraphInput(onnx_input_name);
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(onnx_input_name,
                                                           qnn_input_name,
                                                           input0_info.shape,
                                                           rank4_shape,
                                                           input0_info.qnn_data_type,
                                                           input0_info.quant_param,
                                                           do_op_validation,
                                                           is_graph_input));
    }

    const Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, qnn_input_name);
    QnnTensorWrapper input_tensorwrapper(qnn_input_name, tensor_type, input0_info.qnn_data_type,
                                         input0_info.quant_param, std::move(rank4_shape),
                                         std::move(static_data));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
  }

  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[1], logger, input_names));  // Scale
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[2], logger, input_names));  // Bias

  return Status::OK();
}
Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
@ -100,11 +166,59 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m
param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
const auto& outputs = node_unit.Outputs();
OnnxInputInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info));
// HTP backend can only handle rank 3 inputs/outputs if the batch size is 1. If the batch size is not 1,
// QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4.
if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) ||
output_info.shape.size() != 3 || output_info.shape[0] == 1) {
return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}
//
// The output is meant to be rank 3 with batch size != 1. Must create a QNN InstanceNorm op with a rank 4 output
// that is then reshaped to rank 3 again.
//
const std::string& orig_output_name = outputs[0].node_arg.Name();
std::string op_output_name = orig_output_name + "_ort_qnn_ep_reshape";
std::vector<uint32_t> op_output_shape = {
output_info.shape[0], // N
1, // H == 1
output_info.shape[1], // W
output_info.shape[2], // C
};
QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
output_info.quant_param, std::vector<uint32_t>(op_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
std::move(input_names),
{op_output_name},
std::move(param_tensor_names)),
"Failed to add node.");
const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
// Add Reshape to convert QNN InstanceNorm output back to rank 3 (as expected by the rest of the ONNX graph).
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(op_output_name,
orig_output_name,
op_output_shape,
output_info.shape,
output_info.qnn_data_type,
output_info.quant_param,
do_op_validation,
false,
is_graph_output));
return Status::OK();
}

View file

@ -21,21 +21,26 @@ template <typename QuantType>
static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs) {
return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder,
std::vector<QuantParams<QuantType>>& output_qparams) {
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
bool use_contrib_qdq = false) {
return [input_def, scale_def, bias_def, attrs,
use_contrib_qdq](ModelTestBuilder& builder,
std::vector<QuantParams<QuantType>>& output_qparams) {
// input => Q => DQ =>
NodeArg* input = MakeTestInput(builder, input_def);
QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point);
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point,
use_contrib_qdq);
// scale => Q => DQ =>
NodeArg* scale = MakeTestInput(builder, scale_def);
QuantParams<QuantType> scale_qparams = GetTestInputQuantParams<QuantType>(scale_def);
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point);
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point,
use_contrib_qdq);
// bias (as int32) => DQ =>
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale);
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale,
use_contrib_qdq);
// InstanceNormalization operator.
auto* instance_norm_output = builder.MakeIntermediate();
@ -46,7 +51,8 @@ static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInput
}
// Add instance_norm_output -> Q -> output_u8
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale, output_qparams[0].zero_point);
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale,
output_qparams[0].zero_point, use_contrib_qdq);
};
}
@ -65,7 +71,8 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
ExpectedEPNodeAssignment expected_ep_assignment) {
ExpectedEPNodeAssignment expected_ep_assignment,
bool use_contrib_qdq = false) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
@ -75,11 +82,10 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,
// Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs.
TestQDQModelAccuracy(BuildOpTestCase<float>("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs),
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs),
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs, use_contrib_qdq),
provider_options,
18,
expected_ep_assignment,
1e-5f);
expected_ep_assignment);
}
// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
@ -97,6 +103,19 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8) {
ExpectedEPNodeAssignment::All);
}
// Test 16-bit QDQ InstanceNormalization with a rank 4 input of shape {1, 2, 3, 3}.
// Expects every node to be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, InstanceNormU16) {
  const std::vector<float> x_values{3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f,
                                    -5.40113f, 3.75648f, 6.15357f, -5.25769f, 2.73637f, -0.901382f,
                                    -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f};
  const std::vector<float> scale_values{-0.148738f, -1.45158f};
  const std::vector<float> bias_values{-2.2785083772f, 2.3338717017f};

  // Each definition overrides the value range rather than relying on the data alone.
  const auto x_def = TestInputDef<float>({1, 2, 3, 3}, false, x_values).OverrideValueRange(-10.0f, 10.0f);
  const auto scale_def = TestInputDef<float>({2}, true, scale_values).OverrideValueRange(-2.0f, 2.0f);
  const auto bias_def = TestInputDef<float>({2}, true, bias_values).OverrideValueRange(-3.0f, 3.0f);

  RunInstanceNormQDQTest<uint16_t>(x_def,
                                   scale_def,
                                   bias_def,
                                   {},
                                   ExpectedEPNodeAssignment::All,
                                   true);  // Use contrib Q/DQ ops for 16bit support.
}
// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
// Use an input of rank 3.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
@ -107,6 +126,58 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
ExpectedEPNodeAssignment::All);
}
// 8-bit QDQ InstanceNormalization on a rank 3 input whose batch size is not 1 (shape {2, 2, 3}).
// QNN EP has to wrap the QNN InstanceNorm op with reshapes to handle this case.
// Expects every node to be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1) {
  const std::vector<float> x_values{6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
                                    -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};

  const TestInputDef<float> x_def({2, 2, 3}, false, x_values);
  const TestInputDef<float> scale_def({2}, true, {1.0f, 2.0f});
  const TestInputDef<float> bias_def({2}, true, {1.0f, 3.0f});

  RunInstanceNormQDQTest(x_def, scale_def, bias_def, {}, ExpectedEPNodeAssignment::All);
}
// 16-bit QDQ InstanceNormalization on a rank 3 input whose batch size is not 1 (shape {2, 2, 3}).
// QNN EP has to wrap the QNN InstanceNorm op with reshapes to handle this case.
// Expects every node to be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1) {
  const std::vector<float> x_values{6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
                                    -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};

  const TestInputDef<float> x_def({2, 2, 3}, false, x_values);
  const TestInputDef<float> scale_def({2}, true, {1.0f, 2.0f});
  const TestInputDef<float> bias_def({2}, true, {1.0f, 3.0f});

  RunInstanceNormQDQTest<uint16_t>(x_def,
                                   scale_def,
                                   bias_def,
                                   {},
                                   ExpectedEPNodeAssignment::All,
                                   true);  // Use contrib Q/DQ ops for 16bit support.
}
// 8-bit QDQ InstanceNormalization on a rank 3 input whose batch size is not 1 (shape {2, 2, 3}),
// where input 0 is an initializer. QNN EP has to present the op with a rank 4 shape in this case.
// Expects every node to be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1_Initializer) {
  const std::vector<float> x_values{6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
                                    -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};

  const TestInputDef<float> x_def({2, 2, 3}, true, x_values);  // Input 0 is an initializer.
  const TestInputDef<float> scale_def({2}, true, {1.0f, 2.0f});
  // NOTE(review): bias is the only input not flagged as an initializer here - presumably so the
  // model still has a runtime input; confirm.
  const TestInputDef<float> bias_def({2}, false, {1.0f, 3.0f});

  RunInstanceNormQDQTest(x_def, scale_def, bias_def, {}, ExpectedEPNodeAssignment::All);
}
// 16-bit QDQ InstanceNormalization on a rank 3 input whose batch size is not 1 (shape {2, 2, 3}),
// where input 0 is an initializer. QNN EP has to present the op with a rank 4 shape in this case.
// Expects every node to be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1_Initializer) {
  const std::vector<float> x_values{6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
                                    -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};

  const TestInputDef<float> x_def({2, 2, 3}, true, x_values);  // Input 0 is an initializer.
  const TestInputDef<float> scale_def({2}, true, {1.0f, 2.0f});
  // NOTE(review): bias is the only input not flagged as an initializer here - presumably so the
  // model still has a runtime input; confirm.
  const TestInputDef<float> bias_def({2}, false, {1.0f, 3.0f});

  RunInstanceNormQDQTest<uint16_t>(x_def,
                                   scale_def,
                                   bias_def,
                                   {},
                                   ExpectedEPNodeAssignment::All,
                                   true);  // Use contrib Q/DQ ops for 16-bit support.
}
// Check that QNN InstanceNorm operator does not handle inputs with rank > 4.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank5) {
RunInstanceNormQDQTest(TestInputDef<float>({1, 2, 3, 3, 3}, false, -10.0f, 10.0f),