mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Support only one input in QLinearConcat (#11265)
This commit is contained in:
parent
2e6c2177af
commit
70d97bdf53
3 changed files with 31 additions and 5 deletions
|
|
@ -32,8 +32,8 @@ static inline bool has_same_zero_point(bool is_signed, const Tensor* tensor_x_ze
|
|||
|
||||
QLinearConcat::QLinearConcat(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info) {
|
||||
size_t input_def_count = info.node().InputDefs().size();
|
||||
ORT_ENFORCE(input_def_count >= 8 && (input_def_count - 2) % 3 == 0,
|
||||
"At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
|
||||
ORT_ENFORCE(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
|
||||
"Each input must be (tensor, scale, zero_point) tuple!");
|
||||
|
||||
size_t input_count = (input_def_count - 2) / 3;
|
||||
fixed_lookup_tables_.resize(input_count);
|
||||
|
|
@ -90,8 +90,8 @@ Status QLinearConcat::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
// Number of input tensors to concatenate (tupled)
|
||||
auto input_count_x3 = Node().InputArgCount()[2];
|
||||
ORT_ENFORCE(input_count_x3 >= 6 && input_count_x3 % 3 == 0,
|
||||
"At least two inputs are needed, and each input must be (tensor, scale, zero_point) tuple!");
|
||||
ORT_ENFORCE(input_count_x3 >= 3 && input_count_x3 % 3 == 0,
|
||||
"Each input must be (tensor, scale, zero_point) tuple!");
|
||||
|
||||
// Hold pointers to the input tensors to be used in the PrepareForCompute() step
|
||||
auto input_count = input_count_x3 / 3;
|
||||
|
|
|
|||
|
|
@ -785,7 +785,7 @@ Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` )DOC";
|
|||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
auto numInputs = ctx.getNumInputs();
|
||||
if (numInputs < 8 || (numInputs - 2) % 3 != 0 ||
|
||||
if (numInputs < 5 || (numInputs - 2) % 3 != 0 ||
|
||||
!hasNInputShapes(ctx, static_cast<int>(numInputs))) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -213,5 +213,31 @@ TEST(QLinearConcatS8, ExpectFail_WrongZeroPointType_1) {
|
|||
QLinearConcat3InputsS8({true, true, true}, FailByWrongZeroPointType);
|
||||
}
|
||||
|
||||
void QLinearConcatOneInputS8(std::vector<bool> is_const_inputs, QLinearConcatFailCause fail_by = NoFail) {
|
||||
std::vector<int64_t> y_shape = {2, 1, 3};
|
||||
std::vector<std::vector<int64_t>> x_shapes = {{2, 1, 3}};
|
||||
std::vector<std::vector<int8_t>> x_vecs = { {0, -2, 3, -5, 127, -128} };
|
||||
std::vector<int8_t> x_zero_points = {0};
|
||||
std::vector<float> x_scales = {1.0};
|
||||
|
||||
float y_scale = 0.25;
|
||||
int8_t y_zero_point = 0;
|
||||
std::vector<int8_t> y_vec = { 0, -8, 12, -20, 127, -128 };
|
||||
|
||||
RunQLinearConcat<int8_t>(x_shapes, x_vecs, 1, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
|
||||
is_const_inputs, fail_by);
|
||||
|
||||
RunQLinearConcat<int8_t>(x_shapes, x_vecs, -2, x_scales, x_zero_points, y_shape, y_vec, y_scale, y_zero_point,
|
||||
is_const_inputs, fail_by);
|
||||
}
|
||||
|
||||
TEST(QLinearConcatS8, InputOne_Dynamic) {
|
||||
QLinearConcatOneInputS8({false});
|
||||
}
|
||||
|
||||
TEST(QLinearConcatS8, InputOne_Const) {
|
||||
QLinearConcatOneInputS8({true});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue