Allow QLinearConcat to run with a single (tensor, scale, zero_point) input tuple.

The kernel constructor, Compute(), and the schema's shape-inference function each
lower their minimum input count from two tuples to one (8 -> 5 total input defs,
6 -> 3 variadic args), and the enforce message drops the "at least two inputs"
wording. Tests covering the one-input case (dynamic and constant input) are added.

diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc
index da4c65375a..024f64a652 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc
@@ -32,8 +32,8 @@ static inline bool has_same_zero_point(bool is_signed, const Tensor* tensor_x_ze
 
 QLinearConcat::QLinearConcat(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info) {
   size_t input_def_count = info.node().InputDefs().size();
-  ORT_ENFORCE(input_def_count >= 8 && (input_def_count - 2) % 3 == 0,
-              "At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
+  ORT_ENFORCE(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
+              "Each input must be (tensor, scale, zero_point) tuple!");
 
   size_t input_count = (input_def_count - 2) / 3;
   fixed_lookup_tables_.resize(input_count);
@@ -90,8 +90,8 @@ Status QLinearConcat::Compute(OpKernelContext* ctx) const {
 
   // Number of input tensors to concatenate (tupled)
   auto input_count_x3 = Node().InputArgCount()[2];
-  ORT_ENFORCE(input_count_x3 >= 6 && input_count_x3 % 3 == 0,
-              "At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
+  ORT_ENFORCE(input_count_x3 >= 3 && input_count_x3 % 3 == 0,
+              "Each input must be (tensor, scale, zero_point) tuple!");
 
   // Hold pointers to the input tensors to be used in the PrepareForCompute() step
   auto input_count = input_count_x3 / 3;
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 74dbcd86d4..7e9995aff5 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -785,7 +785,7 @@ Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` )DOC";
       .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
         propagateElemTypeFromInputToOutput(ctx, 0, 0);
         auto numInputs = ctx.getNumInputs();
-        if (numInputs < 8 || (numInputs - 2) % 3 != 0 ||
+        if (numInputs < 5 || (numInputs - 2) % 3 != 0 ||
             !hasNInputShapes(ctx, static_cast<int>(numInputs))) {
           return;
         }
diff --git a/onnxruntime/test/contrib_ops/qlinear_concat_test.cc b/onnxruntime/test/contrib_ops/qlinear_concat_test.cc
index ea59d17adc..506d16d6b1 100644
--- a/onnxruntime/test/contrib_ops/qlinear_concat_test.cc
+++ b/onnxruntime/test/contrib_ops/qlinear_concat_test.cc
@@ -213,5 +213,31 @@ TEST(QLinearConcatS8, ExpectFail_WrongZeroPointType_1) {
   QLinearConcat3InputsS8({true, true, true}, FailByWrongZeroPointType);
 }
 
+void QLinearConcatOneInputS8(std::vector<bool> is_const_inputs, QLinearConcatFailCause fail_by = NoFail) {
+  std::vector<int64_t> y_shape = {2, 1, 3};
+  std::vector<std::vector<int64_t>> x_shapes = {{2, 1, 3}};
+  std::vector<std::vector<int8_t>> x_vecs = { {0, -2, 3, -5, 127, -128} };
+  std::vector<int8_t> x_zero_points = {0};
+  std::vector<float> x_scales = {1.0};
+
+  float y_scale = 0.25;
+  int8_t y_zero_point = 0;
+  std::vector<int8_t> y_vec = { 0, -8, 12, -20, 127, -128 };
+
+  RunQLinearConcat(x_shapes, x_vecs, 1, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
+                   is_const_inputs, fail_by);
+
+  RunQLinearConcat(x_shapes, x_vecs, -2, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
+                   is_const_inputs, fail_by);
+}
+
+TEST(QLinearConcatS8, InputOne_Dynamic) {
+  QLinearConcatOneInputS8({false});
+}
+
+TEST(QLinearConcatS8, InputOne_Const) {
+  QLinearConcatOneInputS8({true});
+}
+
 }  // namespace test
 }  // namespace onnxruntime