From 19464614e7a71c42016dcf0c5bc3b9e7880975db Mon Sep 17 00:00:00 2001 From: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:08:36 -0800 Subject: [PATCH] [NNAPI QDQ] Add QDQ Concat (#10666) * add qdq concat Co-authored-by: Scott McKay Co-authored-by: rachguo --- .../nnapi/nnapi_builtin/builders/helper.cc | 2 + .../nnapi/nnapi_builtin/builders/helper.h | 1 + .../nnapi_builtin/builders/op_builder.cc | 90 +++++++++++++++---- .../builders/op_support_checker.cc | 69 +++++++++++++- onnxruntime/test/optimizer/qdq_test_utils.cc | 40 +++++++++ onnxruntime/test/optimizer/qdq_test_utils.h | 12 ++- .../test/optimizer/qdq_transformer_test.cc | 38 ++------ .../test/providers/nnapi/nnapi_basic_test.cc | 14 +++ 8 files changed, 210 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 5ddf73f3d1..0bae9d7bc6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -84,6 +84,8 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) { return QuantizedOpType::QDQReshape; else if (op_type == "Softmax") return QuantizedOpType::QDQSoftmax; + else if (op_type == "Concat") + return QuantizedOpType::QDQConcat; } else { // throw? 
} diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index 9f8a7d8956..64e15b50fe 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -93,6 +93,7 @@ enum class QuantizedOpType : uint8_t { QDQTranspose, QDQReshape, QDQSoftmax, + QDQConcat, // TODO, add other QDQ NodeUnit types }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 0374c76e4a..b7b406fa5a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -1844,10 +1844,29 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const #pragma region op_concat class ConcatOpBuilder : public BaseOpBuilder { + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; + private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; +bool ConcatOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { + // TODO add support of QLinearConcat + return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQConcat; +} + +void ConcatOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { + if (IsQuantizedOp(node_unit)) { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[i].quant_param); + } + + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp + } +} + Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { 
auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); @@ -1859,21 +1878,40 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& input0 = inputs[0].node_arg.Name(); const auto node_input_size = inputs.size(); - // First if the inputs are uint8, we need verify all the inputs have same scale and zero points - if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) { - auto scale = operand_types.at(input0).operandType.scale; - auto zero_point = operand_types.at(input0).operandType.zeroPoint; + bool is_quant_op = IsQuantizedOp(node_unit); - // Compare scale and zp of input0 to input1~n - for (size_t i = 1; i < node_input_size; i++) { - const auto& type = operand_types.at(inputs[i].node_arg.Name()); - ORT_RETURN_IF_NOT(scale == type.operandType.scale, - "Input[", i, "]'s scale: ", type.operandType.scale, - " is different than input[0]'s scale: ", scale); + if (!is_quant_op) { + // If the inputs are uint8 and this is not a quantized Concat, we need to verify all the inputs have the + // same scale and zero points. + // [Side note: int8 input is not supported currently by the NNAPI EP (enforced in ConcatOpSupportChecker). + // It is supported by NNAPI though and int8 input is allowed to have different scale and zp values.] + // + // ONNX allows Concat (not QLinearConcat, not QDQ concat) to run directly on uint8 without scales and zps. + // NNAPI requires all uint8 inputs to have scale values > 0. (zero point can be 0.) + // See https://android.googlesource.com/platform/frameworks/ml/+/master/nn/common/Validation.cpp#486 + // + // We need to use the scales and zps from the NNAPI input directly, there is no easy way to get the input + // scales and zps in OpSupportChecker, so we need to verify here. 
+ // Also we have to assume the output scale and zp are the same as input 0 + if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) { + auto scale = operand_types.at(input0).operandType.scale; + auto zero_point = operand_types.at(input0).operandType.zeroPoint; - ORT_RETURN_IF_NOT(zero_point == type.operandType.zeroPoint, - "Input[", i, "]'s zero_point: ", type.operandType.zeroPoint, - " is different than input[0]'s zero_point: ", zero_point); + // TODO: if we see scale == 0 in real models we could consider using 1 as a default. This is what TF does + // https://github.com/tensorflow/tensorflow/blob/7737c518a864e54be9b676fe063436ccbbef21b9/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc#L468-L471 + ORT_RETURN_IF_NOT(scale > 0, "NNAPI requires scale to be > 0."); + + // Compare scale and zp of input0 to input1~n + for (size_t i = 1; i < node_input_size; i++) { + const auto& type = operand_types.at(inputs[i].node_arg.Name()); + ORT_RETURN_IF_NOT(scale == type.operandType.scale, + "Input[", i, "]'s scale: ", type.operandType.scale, + " is different than input[0]'s scale: ", scale); + + ORT_RETURN_IF_NOT(zero_point == type.operandType.zeroPoint, + "Input[", i, "]'s zero_point: ", type.operandType.zeroPoint, + " is different than input[0]'s zero_point: ", zero_point); + } } } @@ -1881,10 +1919,31 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const input_names.reserve(node_input_size); for (size_t i = 0; i < node_input_size; i++) { const auto& input = inputs[i].node_arg.Name(); + + if (is_quant_op) { + // scale and zp values consistency was checked in ConcatOpSupportChecker + float scale = 0.0f; + int32_t zero_point = 0; + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + model_builder.GetInitializerTensors(), node_unit.Inputs()[i], node_unit.ModelPath(), + scale, zero_point)); + + ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); + } + 
input_indices.push_back(operand_indices.at(input)); input_names.push_back(input); } + // Get the output scale and zp for quantized concat, default value is from input 0 + float y_scale = operand_types.at(input0).operandType.scale; + int32_t y_zero_point = operand_types.at(input0).operandType.zeroPoint; + if (is_quant_op) { + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), + y_scale, y_zero_point)); + } + int rank = shaper[input0].size(); int32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); @@ -1892,8 +1951,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& output = node_unit.Outputs()[0].node_arg.Name(); ORT_RETURN_IF_ERROR(shaper.Concat(input_names, axis, output)); - OperandType output_operand_type = operand_types.at(input0); - output_operand_type.SetDimensions(shaper[output]); + OperandType output_operand_type(operand_types.at(input0).type, shaper[output], y_scale, y_zero_point); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices, {output}, {output_operand_type})); return Status::OK(); @@ -2653,4 +2711,4 @@ const std::unordered_map& GetOpBuilders() { #pragma endregion } // namespace nnapi -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index 0e01b8a19e..1593b6b9d7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -1539,10 +1539,18 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker { const OpSupportCheckParams& params) const override; bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, 
const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; + + bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; +bool ConcatOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { + // TODO add support of QLinearConcat + return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQConcat; +} + bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; @@ -1560,8 +1568,11 @@ bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* in } bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const { + const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { + const auto& op_type = node_unit.OpType(); + const auto& op_name = node_unit.Name(); + const auto input_size = node_unit.Inputs().size(); int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; @@ -1574,6 +1585,56 @@ bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl( return false; } + if (IsQuantizedOp(node_unit)) { + std::vector input_indices(input_size); + std::iota(input_indices.begin(), input_indices.end(), 0); + if (!IsQuantizedIOSupported(initializers, node_unit, input_indices, params, IOKind::Input)) { + return false; + } + + if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Output)) { + return false; + } + + // Need to verify that all the inputs and the output have the same scale and zp for API 28- + if (params.android_feature_level < 
ANEURALNETWORKS_FEATURE_LEVEL_3) { + std::vector input_scales(input_size); + std::vector input_zps(input_size); + size_t input_idx = 0; + + auto status = GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[input_idx], node_unit.ModelPath(), + input_scales[input_idx], input_zps[input_idx]); + + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name + << "] GetQuantizationScaleAndZeroPoint for input_scale/zp failed, message: " + << status.ErrorMessage(); + return false; + } + + for (++input_idx; input_idx < input_size; ++input_idx) { + if (!HasRequiredScaleAndZeroPoint(initializers, + MakeString("Op [", op_type, "] name [", op_name, "] input ", input_idx), + node_unit.Inputs()[input_idx], + node_unit.ModelPath(), + input_scales[0] /* required_scale */, + input_zps[0] /* required_zp */)) { + return false; + } + } + + // NNAPI (28-) requires the output scale and zp be the same as the input 0 + if (!HasRequiredScaleAndZeroPoint(initializers, + MakeString("Op [", op_type, "] name [", op_name, "]'s output 0"), + node_unit.Outputs()[0], node_unit.ModelPath(), + input_scales[0] /* required_scale */, + input_zps[0] /* required_zp */)) { + return false; + } + } + } + return true; } diff --git a/onnxruntime/test/optimizer/qdq_test_utils.cc b/onnxruntime/test/optimizer/qdq_test_utils.cc index d40889306b..607049917f 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.cc +++ b/onnxruntime/test/optimizer/qdq_test_utils.cc @@ -58,5 +58,45 @@ GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector& input_shape }; } +GetQDQTestCaseFn BuildQDQConcatTestCase(const std::vector>& input_shapes, + int64_t axis, + bool has_input_float, + bool has_input_int8, + bool has_output_int8) { + return [input_shapes, axis, + has_input_float, has_input_int8, has_output_int8]( + ModelTestBuilder& builder) { + auto input_count = input_shapes.size(); + std::vector input_args; + std::vector q_input_args; + for (size_t i = 0; i < input_count; i++) { 
+ input_args.push_back(builder.MakeInput(input_shapes[i], -1.f, 1.f)); + if (i == 0 && has_input_float) { + q_input_args.push_back(input_args.back()); + } else if (i == 0 && has_input_int8) { + q_input_args.push_back(AddQDQNodePair(builder, input_args.back(), 0.05f, 1)); + } else { + q_input_args.push_back(AddQDQNodePair(builder, input_args.back(), 0.05f, 128)); + } + } + auto* concat_output = builder.MakeIntermediate(); + Node& concat_node = builder.AddNode("Concat", q_input_args, {concat_output}); + concat_node.AddAttribute("axis", axis); + + auto* q_concat_output = builder.MakeIntermediate(); + if (has_output_int8) { + builder.AddQuantizeLinearNode(concat_output, 0.05f, 1, q_concat_output); + + auto* output_arg = builder.MakeOutput(); + builder.AddDequantizeLinearNode(q_concat_output, 0.05f, 1, output_arg); + } else { + builder.AddQuantizeLinearNode(concat_output, 0.05f, 128, q_concat_output); + + auto* output_arg = builder.MakeOutput(); + builder.AddDequantizeLinearNode(q_concat_output, 0.05f, 128, output_arg); + } + }; +} + } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index d9bcc9870a..66b9be1813 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -215,9 +215,6 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase( }; } -GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector& input_shape, - const std::vector& reshape_shape); - template GetQDQTestCaseFn BuildQDQSoftMaxTestCase(const std::vector& input_shape, const int64_t& axis = -1) { return [input_shape, axis](ModelTestBuilder& builder) { @@ -242,5 +239,14 @@ GetQDQTestCaseFn BuildQDQSoftMaxTestCase(const std::vector& input_shape }; } +GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector& input_shape, + const std::vector& reshape_shape); + +GetQDQTestCaseFn BuildQDQConcatTestCase(const std::vector>& input_shapes, + 
int64_t axis, + bool has_input_float = false, + bool has_input_int8 = false, + bool has_output_int8 = false); + } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index c9f6805f32..62560ce5e1 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -1626,38 +1626,6 @@ TEST(QDQTransformerTests, Concat) { bool has_input_float = false, bool has_input_int8 = false, bool has_output_int8 = false) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto input_count = input_shapes.size(); - std::vector input_args; - std::vector q_input_args; - for (size_t i = 0; i < input_count; i++) { - input_args.push_back(builder.MakeInput(input_shapes[i], -1.f, 1.f)); - if (i == 0 && has_input_float) { - q_input_args.push_back(input_args.back()); - } else if (i == 0 && has_input_int8) { - q_input_args.push_back(AddQDQNodePair(builder, input_args.back(), 0.05f, 1)); - } else { - q_input_args.push_back(AddQDQNodePair(builder, input_args.back(), 0.05f, 128)); - } - } - auto* concat_output = builder.MakeIntermediate(); - Node& concat_node = builder.AddNode("Concat", q_input_args, {concat_output}); - concat_node.AddAttribute("axis", axis); - - auto* q_concat_output = builder.MakeIntermediate(); - if (has_output_int8) { - builder.AddQuantizeLinearNode(concat_output, 0.05f, 1, q_concat_output); - - auto* output_arg = builder.MakeOutput(); - builder.AddDequantizeLinearNode(q_concat_output, 0.05f, 1, output_arg); - } else { - builder.AddQuantizeLinearNode(concat_output, 0.05f, 128, q_concat_output); - - auto* output_arg = builder.MakeOutput(); - builder.AddDequantizeLinearNode(q_concat_output, 0.05f, 128, output_arg); - } - }; - auto check_graph = [&input_shapes, &has_input_float, &has_input_int8, &has_output_int8](InferenceSessionWrapper& session) { auto op_to_count = 
CountOpsInGraph(session.GetGraph()); if (has_input_float || has_input_int8 || has_output_int8) { @@ -1669,7 +1637,11 @@ TEST(QDQTransformerTests, Concat) { } }; - TransformerTester(build_test_case, + TransformerTester(BuildQDQConcatTestCase(input_shapes, + axis, + has_input_float, + has_input_int8, + has_output_int8), check_graph, TransformerLevel::Level1, TransformerLevel::Level2, diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 042f9ec9e1..5bdac54702 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -290,6 +290,7 @@ static void RunQDQModelTest(const GetQDQTestCaseFn& build_test_case, std::make_unique(0), helper.feeds_, params); #else + ORT_UNUSED_PARAMETER(params); // test load only SessionOptions so; InferenceSessionWrapper session_object{so, GetEnvironment()}; @@ -394,6 +395,19 @@ TEST(NnapiExecutionProviderTest, TestQDQSoftMax) { }); } +TEST(NnapiExecutionProviderTest, TestQDQConcat) { + RunQDQModelTest(BuildQDQConcatTestCase( + { + {1, 6, 36}, + {1, 6, 8}, + {1, 6, 2}, + } /* input_shapes */, + 2 /* axis */), + "nnapi_qdq_test_graph_concat", { + true /* verify_entire_graph_use_ep */ + }); +} + #endif // !(ORT_MINIMAL_BUILD) TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {