diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index e7db3a2672..c6e8af38c0 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -871,6 +871,198 @@ MlasRequantizeOutput( } } +#elif defined(MLAS_TARGET_POWER) + +template +void +MLASCALL +MlasRequantizeOutput( + const int32_t* Input, + size_t InputLeadingDimension, + OutputType* Output, + size_t OutputLeadingDimension, + const int32_t* Bias, + const float* Scale, + bool PerColumnScale, + OutputType ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN + ) +{ + float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; + float MinimumValue = float(std::numeric_limits::lowest() - ZeroPoint); + float MaximumValue = float(std::numeric_limits::max() - ZeroPoint); + + auto PerMatrixScaleVector = vec_splats(PerMatrixScaleValue); + auto MinimumVector = vec_splats(MinimumValue); + auto MaximumVector = vec_splats(MaximumValue); + auto ZeroPointVector = vec_splats(int32_t(ZeroPoint)); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(PerMatrixScaleVector); + + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + + // + // Step through each row of the output matrix. + // + + while (CountM-- > 0) { + + const int32_t* bias = Bias; + const float* scale = PerColumnScale ? Scale : nullptr; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; + + // Process 16 cols at a time + + while (n >= 16) { + + auto IntegerVector0 = vec_xl(0, &RowInput[0]); + auto IntegerVector1 = vec_xl(0, &RowInput[4]); + auto IntegerVector2 = vec_xl(0, &RowInput[8]); + auto IntegerVector3 = vec_xl(0, &RowInput[12]); + RowInput += 16; + + if (bias != nullptr) { + IntegerVector0 = vec_add(IntegerVector0, vec_xl(0, &bias[0])); + IntegerVector1 = vec_add(IntegerVector1, vec_xl(0, &bias[4])); + IntegerVector2 = vec_add(IntegerVector2, vec_xl(0, &bias[8])); + IntegerVector3 = vec_add(IntegerVector3, vec_xl(0, &bias[12])); + bias += 16; + } + + auto FloatVector0 = vec_ctf(IntegerVector0, 0); + auto FloatVector1 = vec_ctf(IntegerVector1, 0); + auto FloatVector2 = vec_ctf(IntegerVector2, 0); + auto FloatVector3 = vec_ctf(IntegerVector3, 0); + + if (scale != nullptr) { + FloatVector0 = vec_mul(FloatVector0, vec_xl(0, &scale[0])); + FloatVector1 = vec_mul(FloatVector1, vec_xl(0, &scale[4])); + FloatVector2 = vec_mul(FloatVector2, vec_xl(0, &scale[8])); + FloatVector3 = vec_mul(FloatVector3, vec_xl(0, &scale[12])); + scale += 16; + } else { + FloatVector0 = vec_mul(FloatVector0, PerMatrixScaleVector); + FloatVector1 = vec_mul(FloatVector1, PerMatrixScaleVector); + FloatVector2 = vec_mul(FloatVector2, PerMatrixScaleVector); + FloatVector3 = vec_mul(FloatVector3, PerMatrixScaleVector); + } + + FloatVector0 = vec_max(FloatVector0, MinimumVector); + FloatVector1 = vec_max(FloatVector1, MinimumVector); + FloatVector2 = vec_max(FloatVector2, MinimumVector); + FloatVector3 = vec_max(FloatVector3, MinimumVector); + + FloatVector0 = vec_min(FloatVector0, MaximumVector); + FloatVector1 = vec_min(FloatVector1, MaximumVector); + FloatVector2 = vec_min(FloatVector2, MaximumVector); + FloatVector3 = vec_min(FloatVector3, MaximumVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + auto IntegerOutVector0 = vec_signed(FloatVector0); + auto IntegerOutVector1 = vec_signed(FloatVector1); + auto IntegerOutVector2 = vec_signed(FloatVector2); + auto IntegerOutVector3 = vec_signed(FloatVector3); + + IntegerOutVector0 = vec_add(IntegerOutVector0, ZeroPointVector); + IntegerOutVector1 = vec_add(IntegerOutVector1, ZeroPointVector); + IntegerOutVector2 = vec_add(IntegerOutVector2, ZeroPointVector); + IntegerOutVector3 = vec_add(IntegerOutVector3, ZeroPointVector); + + auto ShortVector0 = vec_pack(IntegerOutVector0, IntegerOutVector1); + auto ShortVector1 = vec_pack(IntegerOutVector2, IntegerOutVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + + vec_xst(CharVector, 0, (int8_t *) RowOutput); + RowOutput += 16; + n -= 16; + } + + while (n >= 4) { + int8_t OutputBuffer[16]; + + auto IntegerVector = vec_xl(0, &RowInput[0]); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = vec_add(IntegerVector, vec_xl(0, &bias[0])); + bias += 4; + } + + auto FloatVector = vec_ctf(IntegerVector, 0); + + if (scale != nullptr) { + FloatVector = vec_mul(FloatVector, vec_xl(0, scale)); + scale += 4; + } else { + FloatVector = vec_mul(FloatVector, PerMatrixScaleVector); + } + + FloatVector = vec_max(FloatVector, MinimumVector); + FloatVector = vec_min(FloatVector, MaximumVector); + FloatVector = vec_round(FloatVector); + + auto IntegerOutVector = vec_signed(FloatVector); + IntegerOutVector = vec_add(IntegerOutVector, ZeroPointVector); + + auto ShortVector = vec_pack(IntegerOutVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + vec_xst(CharVector, 0, OutputBuffer); + memcpy(RowOutput, OutputBuffer, 4); + + RowOutput += 4; + n -= 4; + } + + while (n > 0) { + auto IntegerValue = RowInput[0]; + RowInput += 1; + + if (bias != nullptr) { + IntegerValue += bias[0]; + bias += 1; + } + + float FloatValue = float(IntegerValue); + float ScaleValue = PerColumnScale ? *scale++ : PerMatrixScaleValue; + + FloatValue *= ScaleValue; + FloatValue = std::max(FloatValue, MinimumValue); + FloatValue = std::min(FloatValue, MaximumValue); + + IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - + MLAS_ROUNDING_BIAS_MAGIC_BITS; + + *RowOutput++ = OutputType(IntegerValue + ZeroPoint); + + n -= 1; + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #else template