ppc64le: optimizing the MlasRequantizeOutput() with VSX (#11659)

This commit is contained in:
Maxiwell S. Garcia 2022-06-10 20:04:52 -03:00 committed by GitHub
parent def78a1b81
commit 0869f4f4ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -871,6 +871,198 @@ MlasRequantizeOutput(
}
}
#elif defined(MLAS_TARGET_POWER)
//
// POWER (VSX) implementation of MlasRequantizeOutput.
//
// Routine Description:
//
//     Converts a CountM x CountN block of 32-bit integer values into
//     OutputType values. For every element the routine:
//
//       1. optionally adds the bias of the element's column,
//       2. converts the sum to float and multiplies it by a scale (either one
//          scale for the whole matrix or one scale per column),
//       3. clamps the product to
//          [lowest(OutputType) - ZeroPoint, max(OutputType) - ZeroPoint],
//       4. rounds to an integer and adds ZeroPoint,
//       5. narrows the result and stores it to the output matrix.
//
//     Each row is processed 16 columns at a time, then 4 columns at a time,
//     and finally one column at a time.
//
//     NOTE(review): the vector paths store exactly one byte per element, so
//     OutputType is presumably int8_t or uint8_t - confirm against the
//     explicit instantiations of this template.
//
// Arguments:
//
//     Input - Base of the int32 source matrix.
//
//     InputLeadingDimension - Distance, in elements, between consecutive rows
//         of Input.
//
//     Output - Base of the destination matrix.
//
//     OutputLeadingDimension - Distance, in elements, between consecutive rows
//         of Output.
//
//     Bias - Optional (may be nullptr) array of int32 values indexed by
//         column; the same values are added to every row.
//
//     Scale - Points to a single float when PerColumnScale is false, otherwise
//         to an array of floats indexed by column.
//
//     PerColumnScale - Selects per-column (true) or per-matrix (false) scale.
//
//     ZeroPoint - Value added to each result after rounding.
//
//     StartM, StartN - Row and column of the first element to process.
//
//     CountM, CountN - Number of rows and columns to process.
//
// Return Value:
//
//     None.
//
template <typename OutputType>
void
MLASCALL
MlasRequantizeOutput(
    const int32_t* Input,
    size_t InputLeadingDimension,
    OutputType* Output,
    size_t OutputLeadingDimension,
    const int32_t* Bias,
    const float* Scale,
    bool PerColumnScale,
    OutputType ZeroPoint,
    size_t StartM,
    size_t StartN,
    size_t CountM,
    size_t CountN
    )
{
    //
    // Scalar constants. The per-matrix scale is only meaningful when a single
    // scale is in use; the clamp bounds are expressed relative to the zero
    // point because the zero point is added after clamping.
    //

    const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale;
    const float MinimumValue = float(std::numeric_limits<OutputType>::lowest() - ZeroPoint);
    const float MaximumValue = float(std::numeric_limits<OutputType>::max() - ZeroPoint);

    //
    // The same constants broadcast to every vector lane.
    //

    auto PerMatrixScaleVector = vec_splats(PerMatrixScaleValue);
    auto MinimumVector = vec_splats(MinimumValue);
    auto MaximumVector = vec_splats(MaximumValue);
    auto ZeroPointVector = vec_splats(int32_t(ZeroPoint));

    // Workaround to avoid 'variable set but not used' message
    MLAS_UNREFERENCED_PARAMETER(PerMatrixScaleVector);

    //
    // Position every cursor at the first element of the requested block. The
    // bias and the per-column scale are indexed by column only, so they restart
    // from the same base on every row.
    //

    const int32_t* const BiasBase = (Bias != nullptr) ? Bias + StartN : nullptr;
    const float* const ScaleBase = PerColumnScale ? Scale + StartN : nullptr;
    const int32_t* InputRow = Input + StartM * InputLeadingDimension + StartN;
    OutputType* OutputRow = Output + StartM * OutputLeadingDimension + StartN;

    //
    // Walk the output matrix one row at a time.
    //

    for (size_t Row = 0; Row < CountM; ++Row) {

        const int32_t* RowBias = BiasBase;
        const float* RowScale = ScaleBase;
        const int32_t* RowInput = InputRow;
        OutputType* RowOutput = OutputRow;
        size_t Remaining = CountN;

        //
        // Main loop: four vectors of four int32 values (16 columns) per pass.
        //

        for (; Remaining >= 16; Remaining -= 16) {

            auto Acc0 = vec_xl(0, RowInput);
            auto Acc1 = vec_xl(0, RowInput + 4);
            auto Acc2 = vec_xl(0, RowInput + 8);
            auto Acc3 = vec_xl(0, RowInput + 12);
            RowInput += 16;

            if (RowBias != nullptr) {
                Acc0 = vec_add(Acc0, vec_xl(0, RowBias));
                Acc1 = vec_add(Acc1, vec_xl(0, RowBias + 4));
                Acc2 = vec_add(Acc2, vec_xl(0, RowBias + 8));
                Acc3 = vec_add(Acc3, vec_xl(0, RowBias + 12));
                RowBias += 16;
            }

            auto Scaled0 = vec_ctf(Acc0, 0);
            auto Scaled1 = vec_ctf(Acc1, 0);
            auto Scaled2 = vec_ctf(Acc2, 0);
            auto Scaled3 = vec_ctf(Acc3, 0);

            if (RowScale != nullptr) {
                Scaled0 = vec_mul(Scaled0, vec_xl(0, RowScale));
                Scaled1 = vec_mul(Scaled1, vec_xl(0, RowScale + 4));
                Scaled2 = vec_mul(Scaled2, vec_xl(0, RowScale + 8));
                Scaled3 = vec_mul(Scaled3, vec_xl(0, RowScale + 12));
                RowScale += 16;
            } else {
                Scaled0 = vec_mul(Scaled0, PerMatrixScaleVector);
                Scaled1 = vec_mul(Scaled1, PerMatrixScaleVector);
                Scaled2 = vec_mul(Scaled2, PerMatrixScaleVector);
                Scaled3 = vec_mul(Scaled3, PerMatrixScaleVector);
            }

            // Clamp (lower bound first, then upper bound) and round.
            Scaled0 = vec_round(vec_min(vec_max(Scaled0, MinimumVector), MaximumVector));
            Scaled1 = vec_round(vec_min(vec_max(Scaled1, MinimumVector), MaximumVector));
            Scaled2 = vec_round(vec_min(vec_max(Scaled2, MinimumVector), MaximumVector));
            Scaled3 = vec_round(vec_min(vec_max(Scaled3, MinimumVector), MaximumVector));

            // Back to int32 lanes, then shift by the zero point.
            auto Quantized0 = vec_add(vec_signed(Scaled0), ZeroPointVector);
            auto Quantized1 = vec_add(vec_signed(Scaled1), ZeroPointVector);
            auto Quantized2 = vec_add(vec_signed(Scaled2), ZeroPointVector);
            auto Quantized3 = vec_add(vec_signed(Scaled3), ZeroPointVector);

            // Narrow 4 x 4 int32 lanes to 2 x 8 int16 lanes to 16 byte lanes.
            auto PackedBytes = vec_pack(vec_pack(Quantized0, Quantized1),
                                        vec_pack(Quantized2, Quantized3));

            vec_xst(PackedBytes, 0, reinterpret_cast<int8_t*>(RowOutput));
            RowOutput += 16;
        }

        //
        // Secondary loop: one vector of four int32 values per pass.
        //

        for (; Remaining >= 4; Remaining -= 4) {

            // A full vector store writes 16 bytes, so the packed result goes
            // to a scratch buffer and only its first four bytes are copied to
            // the output row.
            int8_t Scratch[16];

            auto Acc = vec_xl(0, RowInput);
            RowInput += 4;

            if (RowBias != nullptr) {
                Acc = vec_add(Acc, vec_xl(0, RowBias));
                RowBias += 4;
            }

            auto Scaled = vec_ctf(Acc, 0);

            if (RowScale != nullptr) {
                Scaled = vec_mul(Scaled, vec_xl(0, RowScale));
                RowScale += 4;
            } else {
                Scaled = vec_mul(Scaled, PerMatrixScaleVector);
            }

            Scaled = vec_round(vec_min(vec_max(Scaled, MinimumVector), MaximumVector));

            auto Quantized = vec_add(vec_signed(Scaled), ZeroPointVector);

            // The second operand of each pack is all zeros; it only pads the
            // narrowed vector out to full width.
            auto PackedShorts = vec_pack(Quantized, vec_splats(int32_t(0)));
            auto PackedBytes = vec_pack(PackedShorts, vec_splats(int16_t(0)));

            vec_xst(PackedBytes, 0, Scratch);
            memcpy(RowOutput, Scratch, 4);
            RowOutput += 4;
        }

        //
        // Scalar tail: the remaining columns, one at a time.
        //

        for (; Remaining > 0; --Remaining) {

            int32_t IntegerValue = *RowInput++;

            if (RowBias != nullptr) {
                IntegerValue += *RowBias++;
            }

            float FloatValue = float(IntegerValue);

            if (PerColumnScale) {
                FloatValue *= *RowScale++;
            } else {
                FloatValue *= PerMatrixScaleValue;
            }

            FloatValue = std::min(std::max(FloatValue, MinimumValue), MaximumValue);

            // Float-to-integer conversion via MLAS_ROUNDING_BIAS_MAGIC: add the
            // magic constant, reinterpret the float's bits as an integer and
            // subtract the constant's bit pattern.
            // NOTE(review): presumably the usual magic-number rounding trick -
            // confirm against the definitions of these MLAS helpers.
            IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) -
                MLAS_ROUNDING_BIAS_MAGIC_BITS;

            *RowOutput++ = OutputType(IntegerValue + ZeroPoint);
        }

        //
        // Advance to the next row.
        //

        InputRow += InputLeadingDimension;
        OutputRow += OutputLeadingDimension;
    }
}
#else
template <typename OutputType>