[NNAPI QDQ] Add QDQ Concat (#10666)

* add qdq concat

Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
Guoyu Wang 2022-03-01 15:08:36 -08:00 committed by GitHub
parent 6448ca64e6
commit 19464614e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 210 additions and 56 deletions

View file

@ -84,6 +84,8 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) {
return QuantizedOpType::QDQReshape;
else if (op_type == "Softmax")
return QuantizedOpType::QDQSoftmax;
else if (op_type == "Concat")
return QuantizedOpType::QDQConcat;
} else {
// throw?
}

View file

@ -93,6 +93,7 @@ enum class QuantizedOpType : uint8_t {
QDQTranspose,
QDQReshape,
QDQSoftmax,
QDQConcat,
// TODO, add other QDQ NodeUnit types
};

View file

@ -1844,10 +1844,29 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
#pragma region op_concat
// Op builder for Concat node units. Handles both plain Concat and the
// quantized QDQ Concat form (see IsQuantizedOp).
class ConcatOpBuilder : public BaseOpBuilder {
public:
// For a quantized Concat, passes the quant params (scale / zero point) of every
// input and of output 0 to AddQuantizationScaleAndZeroPointToSkip.
void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
private:
// Adds the ANEURALNETWORKS_CONCATENATION operation for this node unit to the model builder.
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
// True when the node unit's quantized op type is QDQConcat.
bool IsQuantizedOp(const NodeUnit& node_unit) const override;
};
// Reports whether the node unit is a quantized Concat. Only the QDQ form is
// recognized at the moment.
// TODO add support of QLinearConcat
bool ConcatOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
  const auto quant_op_type = GetQuantizedOpType(node_unit);
  return quant_op_type == QuantizedOpType::QDQConcat;
}
// For a quantized Concat, hands the quantization parameters of every input and
// of output 0 to AddQuantizationScaleAndZeroPointToSkip. Does nothing for a
// non-quantized Concat.
void ConcatOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  if (!IsQuantizedOp(node_unit)) {
    return;
  }

  // x_scale, x_zp of each input
  for (const auto& input_def : node_unit.Inputs()) {
    AddQuantizationScaleAndZeroPointToSkip(model_builder, *input_def.quant_param);
  }

  // y_scale, y_zp
  const auto& output_def = node_unit.Outputs()[0];
  AddQuantizationScaleAndZeroPointToSkip(model_builder, *output_def.quant_param);
}
Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
auto& shaper(model_builder.GetShaper());
const auto& operand_indices(model_builder.GetOperandIndices());
@ -1859,21 +1878,40 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto& input0 = inputs[0].node_arg.Name();
const auto node_input_size = inputs.size();
// First if the inputs are uint8, we need verify all the inputs have same scale and zero points
if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) {
auto scale = operand_types.at(input0).operandType.scale;
auto zero_point = operand_types.at(input0).operandType.zeroPoint;
bool is_quant_op = IsQuantizedOp(node_unit);
// Compare scale and zp of input0 to input1~n
for (size_t i = 1; i < node_input_size; i++) {
const auto& type = operand_types.at(inputs[i].node_arg.Name());
ORT_RETURN_IF_NOT(scale == type.operandType.scale,
"Input[", i, "]'s scale: ", type.operandType.scale,
" is different than input[0]'s scale: ", scale);
if (!is_quant_op) {
// If the inputs are uint8 and this is not a quantized Concat, we need to verify all the inputs have the
// same scale and zero points.
// [Side note: int8 input is not supported currently by the NNAPI EP (enforced in ConcatOpSupportChecker).
// it is supported by NNAPI though and int8 input is allowed to have different scale and zp values.]
//
// ONNX allows Concat (not QLinearConcat, not QDQ Concat) to run directly on uint8 without scales and zps.
// NNAPI requires all uint8 inputs to have scale values > 0. (zero point can be 0.)
// See https://android.googlesource.com/platform/frameworks/ml/+/master/nn/common/Validation.cpp#486
//
// We need to use the scales and zps from the NNAPI input directly, there is no easy way to get the input
// scales and zps in OpSupportChecker, so we need to verify here.
// Also we have to assume the output scale and zp are the same as input 0
if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) {
auto scale = operand_types.at(input0).operandType.scale;
auto zero_point = operand_types.at(input0).operandType.zeroPoint;
ORT_RETURN_IF_NOT(zero_point == type.operandType.zeroPoint,
"Input[", i, "]'s zero_point: ", type.operandType.zeroPoint,
" is different than input[0]'s zero_point: ", zero_point);
// TODO: if we see scale == 0 in real models we could consider using 1 as a default. This is what TF does
// https://github.com/tensorflow/tensorflow/blob/7737c518a864e54be9b676fe063436ccbbef21b9/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc#L468-L471
ORT_RETURN_IF_NOT(scale > 0, "NNAPI requires scale to be > 0.");
// Compare scale and zp of input0 to input1~n
for (size_t i = 1; i < node_input_size; i++) {
const auto& type = operand_types.at(inputs[i].node_arg.Name());
ORT_RETURN_IF_NOT(scale == type.operandType.scale,
"Input[", i, "]'s scale: ", type.operandType.scale,
" is different than input[0]'s scale: ", scale);
ORT_RETURN_IF_NOT(zero_point == type.operandType.zeroPoint,
"Input[", i, "]'s zero_point: ", type.operandType.zeroPoint,
" is different than input[0]'s zero_point: ", zero_point);
}
}
}
@ -1881,10 +1919,31 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
input_names.reserve(node_input_size);
for (size_t i = 0; i < node_input_size; i++) {
const auto& input = inputs[i].node_arg.Name();
if (is_quant_op) {
// scale and zp values consistency was checked in ConcatOpSupportChecker
float scale = 0.0f;
int32_t zero_point = 0;
ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint(
model_builder.GetInitializerTensors(), node_unit.Inputs()[i], node_unit.ModelPath(),
scale, zero_point));
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point));
}
input_indices.push_back(operand_indices.at(input));
input_names.push_back(input);
}
// Get the output scale and zp for quantized concat, default value is from input 0
float y_scale = operand_types.at(input0).operandType.scale;
int32_t y_zero_point = operand_types.at(input0).operandType.zeroPoint;
if (is_quant_op) {
ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint(
model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(),
y_scale, y_zero_point));
}
int rank = shaper[input0].size();
int32_t axis = static_cast<int32_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
@ -1892,8 +1951,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
const auto& output = node_unit.Outputs()[0].node_arg.Name();
ORT_RETURN_IF_ERROR(shaper.Concat(input_names, axis, output));
OperandType output_operand_type = operand_types.at(input0);
output_operand_type.SetDimensions(shaper[output]);
OperandType output_operand_type(operand_types.at(input0).type, shaper[output], y_scale, y_zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices,
{output}, {output_operand_type}));
return Status::OK();
@ -2653,4 +2711,4 @@ const std::unordered_map<std::string, const IOpBuilder*>& GetOpBuilders() {
#pragma endregion
} // namespace nnapi
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1539,10 +1539,18 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker {
const OpSupportCheckParams& params) const override;
bool HasSupportedInputOutputsImpl(
const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
const OpSupportCheckParams& /* params */) const override;
const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const override;
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
bool IsQuantizedOp(const NodeUnit& node_unit) const override;
};
// Reports whether the node unit is a quantized Concat. Only the QDQ form is
// recognized at the moment.
// TODO add support of QLinearConcat
bool ConcatOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const {
  const auto quant_op_type = GetQuantizedOpType(node_unit);
  return quant_op_type == QuantizedOpType::QDQConcat;
}
bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
const OpSupportCheckParams& /* params */) const {
Shape input_shape;
@ -1560,8 +1568,11 @@ bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* in
}
bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl(
const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
const OpSupportCheckParams& /* params */) const {
const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const {
const auto& op_type = node_unit.OpType();
const auto& op_name = node_unit.Name();
const auto input_size = node_unit.Inputs().size();
int32_t input_type;
if (!GetType(node_unit.Inputs()[0].node_arg, input_type))
return false;
@ -1574,6 +1585,56 @@ bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl(
return false;
}
if (IsQuantizedOp(node_unit)) {
std::vector<size_t> input_indices(input_size);
std::iota(input_indices.begin(), input_indices.end(), 0);
if (!IsQuantizedIOSupported(initializers, node_unit, input_indices, params, IOKind::Input)) {
return false;
}
if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Output)) {
return false;
}
// Need to verify that all the inputs and the output have the same scale and zp for API 28-
if (params.android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) {
std::vector<float> input_scales(input_size);
std::vector<int32_t> input_zps(input_size);
size_t input_idx = 0;
auto status = GetQuantizationScaleAndZeroPoint(
initializers, node_unit.Inputs()[input_idx], node_unit.ModelPath(),
input_scales[input_idx], input_zps[input_idx]);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name
<< "] GetQuantizationScaleAndZeroPoint for input_scale/zp failed, message: "
<< status.ErrorMessage();
return false;
}
for (++input_idx; input_idx < input_size; ++input_idx) {
if (!HasRequiredScaleAndZeroPoint(initializers,
MakeString("Op [", op_type, "] name [", op_name, "] input ", input_idx),
node_unit.Inputs()[input_idx],
node_unit.ModelPath(),
input_scales[0] /* required_scale */,
input_zps[0] /* required_zp */)) {
return false;
}
}
// NNAPI (28-) requires the output scale and zp be the same as the input 0
if (!HasRequiredScaleAndZeroPoint(initializers,
MakeString("Op [", op_type, "] name [", op_name, "]'s output 0"),
node_unit.Outputs()[0], node_unit.ModelPath(),
input_scales[0] /* required_scale */,
input_zps[0] /* required_zp */)) {
return false;
}
}
}
return true;
}

View file

@ -58,5 +58,45 @@ GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape
};
}
// Returns a test-case builder that creates a Concat over `input_shapes.size()`
// float inputs (values in [-1, 1]) followed by a Q -> DQ pair on the output.
//
// Every input goes through a uint8 QDQ pair (scale 0.05, zero point 128), except
// input 0, which is
//   - fed to Concat directly as float when `has_input_float` is set, or
//   - given an int8 QDQ pair (scale 0.05, zero point 1) when `has_input_int8` is set.
// The output Q/DQ pair uses int8 (scale 0.05, zero point 1) when `has_output_int8`
// is set, otherwise uint8 (scale 0.05, zero point 128).
GetQDQTestCaseFn BuildQDQConcatTestCase(const std::vector<std::vector<int64_t>>& input_shapes,
                                        int64_t axis,
                                        bool has_input_float,
                                        bool has_input_int8,
                                        bool has_output_int8) {
  return [=](ModelTestBuilder& builder) {
    std::vector<NodeArg*> concat_inputs;
    for (size_t idx = 0; idx < input_shapes.size(); ++idx) {
      auto* float_input = builder.MakeInput<float>(input_shapes[idx], -1.f, 1.f);

      // Only the first input may deviate from the default uint8 QDQ pair.
      const bool is_first = (idx == 0);
      if (is_first && has_input_float) {
        concat_inputs.push_back(float_input);
      } else if (is_first && has_input_int8) {
        concat_inputs.push_back(AddQDQNodePair<int8_t>(builder, float_input, 0.05f, 1));
      } else {
        concat_inputs.push_back(AddQDQNodePair<uint8_t>(builder, float_input, 0.05f, 128));
      }
    }

    auto* concat_result = builder.MakeIntermediate();
    Node& concat = builder.AddNode("Concat", concat_inputs, {concat_result});
    concat.AddAttribute("axis", axis);

    // Output Q -> DQ pair; the graph output is created between the two nodes.
    auto* quantized_result = builder.MakeIntermediate();
    if (!has_output_int8) {
      builder.AddQuantizeLinearNode<uint8_t>(concat_result, 0.05f, 128, quantized_result);
      auto* graph_output = builder.MakeOutput();
      builder.AddDequantizeLinearNode<uint8_t>(quantized_result, 0.05f, 128, graph_output);
    } else {
      builder.AddQuantizeLinearNode<int8_t>(concat_result, 0.05f, 1, quantized_result);
      auto* graph_output = builder.MakeOutput();
      builder.AddDequantizeLinearNode<int8_t>(quantized_result, 0.05f, 1, graph_output);
    }
  };
}
} // namespace test
} // namespace onnxruntime

View file

@ -215,9 +215,6 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase(
};
}
GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& reshape_shape);
template <typename InputType, typename OutputType>
GetQDQTestCaseFn BuildQDQSoftMaxTestCase(const std::vector<int64_t>& input_shape, const int64_t& axis = -1) {
return [input_shape, axis](ModelTestBuilder& builder) {
@ -242,5 +239,14 @@ GetQDQTestCaseFn BuildQDQSoftMaxTestCase(const std::vector<int64_t>& input_shape
};
}
GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& reshape_shape);
// Builds a QDQ Concat test graph: one float graph input per entry of `input_shapes`,
// a Concat along `axis`, and a Q -> DQ pair on the Concat output.
// By default every input and the output use uint8 QDQ pairs.
//   has_input_float: input 0 is passed to Concat as float, without a QDQ pair.
//   has_input_int8:  input 0 uses an int8 QDQ pair (only checked when has_input_float is false).
//   has_output_int8: the output Q -> DQ pair uses int8 instead of uint8.
GetQDQTestCaseFn BuildQDQConcatTestCase(const std::vector<std::vector<int64_t>>& input_shapes,
int64_t axis,
bool has_input_float = false,
bool has_input_int8 = false,
bool has_output_int8 = false);
} // namespace test
} // namespace onnxruntime

View file

@ -1626,38 +1626,6 @@ TEST(QDQTransformerTests, Concat) {
bool has_input_float = false,
bool has_input_int8 = false,
bool has_output_int8 = false) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto input_count = input_shapes.size();
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> q_input_args;
for (size_t i = 0; i < input_count; i++) {
input_args.push_back(builder.MakeInput<float>(input_shapes[i], -1.f, 1.f));
if (i == 0 && has_input_float) {
q_input_args.push_back(input_args.back());
} else if (i == 0 && has_input_int8) {
q_input_args.push_back(AddQDQNodePair<int8_t>(builder, input_args.back(), 0.05f, 1));
} else {
q_input_args.push_back(AddQDQNodePair<uint8_t>(builder, input_args.back(), 0.05f, 128));
}
}
auto* concat_output = builder.MakeIntermediate();
Node& concat_node = builder.AddNode("Concat", q_input_args, {concat_output});
concat_node.AddAttribute("axis", axis);
auto* q_concat_output = builder.MakeIntermediate();
if (has_output_int8) {
builder.AddQuantizeLinearNode<int8_t>(concat_output, 0.05f, 1, q_concat_output);
auto* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<int8_t>(q_concat_output, 0.05f, 1, output_arg);
} else {
builder.AddQuantizeLinearNode<uint8_t>(concat_output, 0.05f, 128, q_concat_output);
auto* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<uint8_t>(q_concat_output, 0.05f, 128, output_arg);
}
};
auto check_graph = [&input_shapes, &has_input_float, &has_input_int8, &has_output_int8](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
if (has_input_float || has_input_int8 || has_output_int8) {
@ -1669,7 +1637,11 @@ TEST(QDQTransformerTests, Concat) {
}
};
TransformerTester(build_test_case,
TransformerTester(BuildQDQConcatTestCase(input_shapes,
axis,
has_input_float,
has_input_int8,
has_output_int8),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,

View file

@ -290,6 +290,7 @@ static void RunQDQModelTest(const GetQDQTestCaseFn& build_test_case,
std::make_unique<NnapiExecutionProvider>(0),
helper.feeds_, params);
#else
ORT_UNUSED_PARAMETER(params);
// test load only
SessionOptions so;
InferenceSessionWrapper session_object{so, GetEnvironment()};
@ -394,6 +395,19 @@ TEST(NnapiExecutionProviderTest, TestQDQSoftMax) {
});
}
// Runs a QDQ Concat of three inputs along axis 2 and requires the entire graph
// to be handled by the execution provider under test.
TEST(NnapiExecutionProviderTest, TestQDQConcat) {
  const std::vector<std::vector<int64_t>> input_shapes = {
      {1, 6, 36},
      {1, 6, 8},
      {1, 6, 2},
  };
  const int64_t axis = 2;

  RunQDQModelTest(BuildQDQConcatTestCase(input_shapes, axis),
                  "nnapi_qdq_test_graph_concat",
                  {true /* verify_entire_graph_use_ep */});
}
#endif // !(ORT_MINIMAL_BUILD)
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {