Support only one input in QLinearConcat (#11265)

This commit is contained in:
Zhang Lei 2022-04-19 20:55:51 -07:00 committed by GitHub
parent 2e6c2177af
commit 70d97bdf53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 5 deletions

View file

@ -32,8 +32,8 @@ static inline bool has_same_zero_point(bool is_signed, const Tensor* tensor_x_ze
QLinearConcat::QLinearConcat(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info) {
size_t input_def_count = info.node().InputDefs().size();
ORT_ENFORCE(input_def_count >= 8 && (input_def_count - 2) % 3 == 0,
"At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
ORT_ENFORCE(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
"Each input must be (tensor, scale, zero_point) tuple!");
size_t input_count = (input_def_count - 2) / 3;
fixed_lookup_tables_.resize(input_count);
@ -90,8 +90,8 @@ Status QLinearConcat::Compute(OpKernelContext* ctx) const {
// Number of input tensors to concatenate (tupled)
auto input_count_x3 = Node().InputArgCount()[2];
ORT_ENFORCE(input_count_x3 >= 6 && input_count_x3 % 3 == 0,
"At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
ORT_ENFORCE(input_count_x3 >= 3 && input_count_x3 % 3 == 0,
"Each input must be (tensor, scale, zero_point) tuple!");
// Hold pointers to the input tensors to be used in the PrepareForCompute() step
auto input_count = input_count_x3 / 3;

View file

@ -785,7 +785,7 @@ Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` )DOC";
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
auto numInputs = ctx.getNumInputs();
if (numInputs < 8 || (numInputs - 2) % 3 != 0 ||
if (numInputs < 5 || (numInputs - 2) % 3 != 0 ||
!hasNInputShapes(ctx, static_cast<int>(numInputs))) {
return;
}

View file

@ -213,5 +213,31 @@ TEST(QLinearConcatS8, ExpectFail_WrongZeroPointType_1) {
QLinearConcat3InputsS8({true, true, true}, FailByWrongZeroPointType);
}
void QLinearConcatOneInputS8(std::vector<bool> is_const_inputs, QLinearConcatFailCause fail_by = NoFail) {
std::vector<int64_t> y_shape = {2, 1, 3};
std::vector<std::vector<int64_t>> x_shapes = {{2, 1, 3}};
std::vector<std::vector<int8_t>> x_vecs = { {0, -2, 3, -5, 127, -128} };
std::vector<int8_t> x_zero_points = {0};
std::vector<float> x_scales = {1.0};
float y_scale = 0.25;
int8_t y_zero_point = 0;
std::vector<int8_t> y_vec = { 0, -8, 12, -20, 127, -128 };
RunQLinearConcat<int8_t>(x_shapes, x_vecs, 1, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
is_const_inputs, fail_by);
RunQLinearConcat<int8_t>(x_shapes, x_vecs, -2, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
is_const_inputs, fail_by);
}
TEST(QLinearConcatS8, InputOne_Dynamic) {
QLinearConcatOneInputS8({false});
}
TEST(QLinearConcatS8, InputOne_Const) {
QLinearConcatOneInputS8({true});
}
} // namespace test
} // namespace onnxruntime