Make requantize a qgemm post processor (#7850)

Description:
Change the requantize interface so that the output can be processed block by block. This enables us to make requantize a post processor of QGEMM.

Motivation and Context

Previous changes showed that we can improve performance by parallelizing batch GEMM. Unfortunately, we could not parallelize the batch GEMM in quantize_linear_matmul because of the requantize operation at the end of each GEMM. By changing requantize to be a QGEMM post processor, we can now parallelize the batch operation.

Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
Chen Fu 2021-05-27 15:05:04 -07:00 committed by GitHub
parent ccdedf1b2e
commit 8140e3fde5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 227 additions and 194 deletions

View file

@ -963,19 +963,83 @@ MlasQuantizeLinear(
OutputType ZeroPoint
);
/**
 * @brief Requantize a block of the intermediate buffer to the output buffer,
 *        optionally adding the supplied bias
 *
 * The block to process is the CountM x CountN sub-matrix whose first element
 * is at row StartM, column StartN. The same offsets are applied to both the
 * input and the output matrix.
 *
 * @param Input                  Input matrix (int32 intermediate values)
 * @param InputLeadingDimension  Input matrix leading dimension: number of
 *                               elements between the starts of two
 *                               consecutive input rows
 * @param Output                 Output matrix (uint8 requantized values)
 * @param OutputLeadingDimension Output matrix leading dimension: number of
 *                               elements between the starts of two
 *                               consecutive output rows
 * @param Bias                   Optional bias vector, to be added
 *                               to the input before quantization; may be
 *                               nullptr. It is indexed by matrix column, so
 *                               the entry used for the first column of the
 *                               block is Bias[StartN]
 * @param Scale                  Quantization scale. Points to a single value
 *                               when PerColumnScale is false; otherwise it is
 *                               indexed by matrix column, starting at
 *                               Scale[StartN] for the block
 * @param PerColumnScale         true if scale is per-column
 * @param ZeroPoint              quantization zero point value
 * @param StartM                 Row index of the first row of the block
 * @param StartN                 Column index of the first column of the block
 * @param CountM                 Number of rows in the block
 * @param CountN                 Number of columns in the block
 * @return                       None
 */
void
MLASCALL
MlasRequantizeOutput(
const int32_t* Input,
size_t InputLeadingDimension,
uint8_t* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
size_t M,
size_t N,
const float* Scale,
bool PerColumnScale,
uint8_t ZeroPoint
uint8_t ZeroPoint,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN
);
/**
 * @brief QGEMM output processor that requantizes each finished block of the
 *        int32 result matrix into a uint8 output matrix.
 *
 * The object only stores the requantization parameters handed to its
 * constructor; every Process() call forwards them, together with the block
 * coordinates it receives, to MlasRequantizeOutput.
 */
class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
{
public:
    /**
     * @param Output                 Destination uint8 matrix
     * @param OutputLeadingDimension Leading dimension of the destination matrix
     * @param Bias                   Bias vector forwarded to MlasRequantizeOutput
     * @param Scale                  Quantization scale(s)
     * @param PerColumnScale         true if Scale holds per-column values
     * @param ZeroPoint              Quantization zero point of the output
     */
    MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR(
        uint8_t* Output,
        size_t OutputLeadingDimension,
        const int32_t* Bias,
        const float* Scale,
        bool PerColumnScale,
        uint8_t ZeroPoint)
        : Destination_{Output},
          DestinationStride_{OutputLeadingDimension},
          BiasVector_{Bias},
          ScaleData_{Scale},
          ScaleIsPerColumn_{PerColumnScale},
          OutputZeroPoint_{ZeroPoint}
    {
    }

    /**
     * @brief Requantize the CountM x CountN block of C that starts at
     *        (StartM, StartN) into the stored output matrix.
     *
     * @param C      int32 result matrix
     * @param StartM First row of the block
     * @param StartN First column of the block
     * @param CountM Number of rows in the block
     * @param CountN Number of columns in the block
     * @param ldc    Leading dimension of C
     */
    void Process(const int32_t* C,
                 size_t StartM,
                 size_t StartN,
                 size_t CountM,
                 size_t CountN,
                 size_t ldc) const override
    {
        MlasRequantizeOutput(
            C,
            ldc,
            Destination_,
            DestinationStride_,
            BiasVector_,
            ScaleData_,
            ScaleIsPerColumn_,
            OutputZeroPoint_,
            StartM,
            StartN,
            CountM,
            CountN);
    }

private:
    uint8_t* Destination_;         // requantized output matrix
    size_t DestinationStride_;     // leading dimension of the output matrix
    const int32_t* BiasVector_;    // bias pointer passed through unchanged
    const float* ScaleData_;       // single scale or per-column scales
    bool ScaleIsPerColumn_;        // selects how ScaleData_ is interpreted
    uint8_t OutputZeroPoint_;      // zero point added after scaling
};
void
MLASCALL
MlasFindMinMaxElement(

View file

@ -121,7 +121,9 @@ MlasQLinearGlobalAveragePoolNchw(
int32x2_t vacc = vadd_s32(vget_high_s32(vacc_lo), vget_low_s32(vacc_lo));
*sum_buffer++ = vget_lane_s32(vpadd_s32(vacc, vacc), 0);
}
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast<uint8_t>(ZeroPointOutput));
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false,
static_cast<uint8_t>(ZeroPointOutput), 0, 0, 1, Channels);
}
MLAS_FORCEINLINE
@ -256,7 +258,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch(
vst1q_s32(acc + 4, vacc_hi);
}
}
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point);
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false,
Output_zero_point, 0, 0, 1, Channels);
}
#elif defined(MLAS_SSE2_INTRINSICS)
@ -323,7 +326,8 @@ MlasQLinearGlobalAveragePoolNchw(
vsums = _mm_add_epi32(vsums, vshuf);
*sum_buffer++ = _mm_cvtsi128_si32(vsums);
}
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast<uint8_t>(ZeroPointOutput));
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false,
static_cast<uint8_t>(ZeroPointOutput), 0, 0, 1, Channels);
}
MLAS_FORCEINLINE
@ -515,7 +519,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch(
_mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi);
}
}
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point);
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false,
Output_zero_point, 0, 0, 1, Channels);
}
#else

View file

@ -356,65 +356,46 @@ void
MLASCALL
MlasRequantizeOutput(
const int32_t* Input,
size_t InputLeadingDimension,
uint8_t* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
size_t M,
size_t N,
const float* Scale,
bool PerColumnScale,
uint8_t ZeroPoint
uint8_t ZeroPoint,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN
)
/*++
Routine Description:
This routine requantizes the intermediate buffer to the output buffer
optionally adding the supplied bias.
Arguments:
Input - Supplies the input matrix.
Output - Supplies the output matrix.
Bias - Supplies the optional bias vector to be added to the input buffer
before requantization.
Buffer - Supplies the output matrix.
M - Supplies the number of elements of the bias vector and the number of
rows in the output matrix.
N - Supplies the number of columns of the output matrix.
Scale - Supplies the quantization scale.
PerColumnScale - Supplies true if the quantization scale has per-column
values, else false if a single quantization scale applies to the
entire matrix.
ZeroPoint - Supplies the quantization zero point value.
Return Value:
None.
--*/
{
const __m128 PerMatrixScaleVector = PerColumnScale ? _mm_setzero_ps() : _mm_load1_ps(Scale);
const __m128 MinimumValueVector = _mm_set1_ps(float(0 - ZeroPoint));
const __m128 MaximumValueVector = _mm_set1_ps(float(255 - ZeroPoint));
const __m128i ZeroPointVector = _mm_set1_epi32(ZeroPoint);
if (nullptr != Bias) {
Bias += StartN;
}
if (PerColumnScale) {
Scale += StartN;
}
Input += StartM * InputLeadingDimension + StartN;
Output += StartM * OutputLeadingDimension + StartN;
//
// Step through each row of the output matrix.
//
while (M-- > 0) {
while (CountM-- > 0) {
const int32_t* bias = Bias;
const float* scale = PerColumnScale ? Scale : nullptr;
size_t n = N;
size_t n = CountN;
auto* RowInput = Input;
auto* RowOutput = Output;
//
// Process 16 columns of the matrices at a time.
@ -426,11 +407,11 @@ Return Value:
// Load the input data and optionally add the per-column bias.
//
__m128i IntegerVector0 = _mm_loadu_si128((const __m128i *)&Input[0]);
__m128i IntegerVector1 = _mm_loadu_si128((const __m128i *)&Input[4]);
__m128i IntegerVector2 = _mm_loadu_si128((const __m128i *)&Input[8]);
__m128i IntegerVector3 = _mm_loadu_si128((const __m128i *)&Input[12]);
Input += 16;
__m128i IntegerVector0 = _mm_loadu_si128((const __m128i*)&RowInput[0]);
__m128i IntegerVector1 = _mm_loadu_si128((const __m128i*)&RowInput[4]);
__m128i IntegerVector2 = _mm_loadu_si128((const __m128i*)&RowInput[8]);
__m128i IntegerVector3 = _mm_loadu_si128((const __m128i*)&RowInput[12]);
RowInput += 16;
if (bias != nullptr) {
IntegerVector0 = _mm_add_epi32(IntegerVector0, _mm_loadu_si128((const __m128i *)&bias[0]));
@ -491,8 +472,8 @@ Return Value:
__m128i ByteVector = _mm_packus_epi16(WordVector0, WordVector1);
_mm_storeu_si128((__m128i*)Output, ByteVector);
Output += 16;
_mm_storeu_si128((__m128i*)RowOutput, ByteVector);
RowOutput += 16;
n -= 16;
}
@ -511,8 +492,8 @@ Return Value:
if (n >= 4) {
IntegerVector = _mm_loadu_si128((const __m128i*)&Input[0]);
Input += 4;
IntegerVector = _mm_loadu_si128((const __m128i*)&RowInput[0]);
RowInput += 4;
if (bias != nullptr) {
IntegerVector = _mm_add_epi32(IntegerVector, _mm_loadu_si128((const __m128i*)&bias[0]));
@ -521,7 +502,7 @@ Return Value:
} else {
int32_t IntegerValue = *Input++;
int32_t IntegerValue = *RowInput++;
if (bias != nullptr) {
IntegerValue += *bias++;
@ -567,19 +548,23 @@ Return Value:
if (n >= 4) {
*reinterpret_cast<uint32_t*>(Output) = OutputValue;
Output += 4;
*reinterpret_cast<uint32_t*>(RowOutput) = OutputValue;
RowOutput += 4;
n -= 4;
} else {
*Output = uint8_t(OutputValue);
Output += 1;
*RowOutput = uint8_t(OutputValue);
RowOutput += 1;
n -= 1;
}
}
// Next Row
Input += InputLeadingDimension;
Output += OutputLeadingDimension;
}
}
@ -589,63 +574,44 @@ void
MLASCALL
MlasRequantizeOutput(
const int32_t* Input,
size_t InputLeadingDimension,
uint8_t* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
size_t M,
size_t N,
const float* Scale,
bool PerColumnScale,
uint8_t ZeroPoint
uint8_t ZeroPoint,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN
)
/*++
Routine Description:
This routine requantizes the intermediate buffer to the output buffer
optionally adding the supplied bias.
Arguments:
Input - Supplies the input matrix.
Output - Supplies the output matrix.
Bias - Supplies the optional bias vector to be added to the input buffer
before requantization.
Buffer - Supplies the output matrix.
M - Supplies the number of elements of the bias vector and the number of
rows in the output matrix.
N - Supplies the number of columns of the output matrix.
Scale - Supplies the quantization scale.
PerColumnScale - Supplies true if the quantization scale has per-column
values, else false if a single quantization scale applies to the
entire matrix.
ZeroPoint - Supplies the quantization zero point value.
Return Value:
None.
--*/
{
const float32x4_t PerMatrixScaleVector = PerColumnScale ? vdupq_n_f32(0) : vld1q_dup_f32(Scale);
const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint);
if (nullptr != Bias) {
Bias += StartN;
}
if (PerColumnScale) {
Scale += StartN;
}
Input += StartM * InputLeadingDimension + StartN;
Output += StartM * OutputLeadingDimension + StartN;
//
// Step through each row of the output matrix.
//
while (M-- > 0) {
while (CountM-- > 0) {
const int32_t* bias = Bias;
const float* scale = PerColumnScale ? Scale : nullptr;
size_t n = N;
size_t n = CountN;
auto* RowInput = Input;
auto* RowOutput = Output;
//
// Process 16 columns of the matrices at a time.
@ -659,11 +625,11 @@ Return Value:
int32x4x4_t IntegerVector;
IntegerVector.val[0] = vld1q_s32(&Input[0]);
IntegerVector.val[1] = vld1q_s32(&Input[4]);
IntegerVector.val[2] = vld1q_s32(&Input[8]);
IntegerVector.val[3] = vld1q_s32(&Input[12]);
Input += 16;
IntegerVector.val[0] = vld1q_s32(&RowInput[0]);
IntegerVector.val[1] = vld1q_s32(&RowInput[4]);
IntegerVector.val[2] = vld1q_s32(&RowInput[8]);
IntegerVector.val[3] = vld1q_s32(&RowInput[12]);
RowInput += 16;
if (bias != nullptr) {
IntegerVector.val[0] = vaddq_s32(IntegerVector.val[0], vld1q_s32(&bias[0]));
@ -731,8 +697,8 @@ Return Value:
WordVector.val[0] = vqaddq_s16(WordVector.val[0], ZeroPointVector);
WordVector.val[1] = vqaddq_s16(WordVector.val[1], ZeroPointVector);
vst1q_u8(Output, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1]));
Output += 16;
vst1q_u8(RowOutput, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1]));
RowOutput += 16;
n -= 16;
}
@ -751,8 +717,8 @@ Return Value:
if (n >= 4) {
IntegerVector = vld1q_s32(&Input[0]);
Input += 4;
IntegerVector = vld1q_s32(&RowInput[0]);
RowInput += 4;
if (bias != nullptr) {
IntegerVector = vaddq_s32(IntegerVector, vld1q_s32(&bias[0]));
@ -761,8 +727,8 @@ Return Value:
} else {
IntegerVector = vld1q_dup_s32(Input);
Input += 1;
IntegerVector = vld1q_dup_s32(RowInput);
RowInput += 1;
if (bias != nullptr) {
IntegerVector = vaddq_s32(IntegerVector, vld1q_dup_s32(bias));
@ -813,19 +779,24 @@ Return Value:
if (n >= 4) {
vst1q_lane_u32(reinterpret_cast<uint32_t*>(Output), vreinterpretq_u32_u8(ByteVector), 0);
Output += 4;
vst1q_lane_u32(reinterpret_cast<uint32_t*>(RowOutput),
vreinterpretq_u32_u8(ByteVector), 0);
RowOutput += 4;
n -= 4;
} else {
vst1q_lane_u8(Output, ByteVector, 0);
Output += 1;
vst1q_lane_u8(RowOutput, ByteVector, 0);
RowOutput += 1;
n -= 1;
}
}
// Next Row
Input += InputLeadingDimension;
Output += OutputLeadingDimension;
}
}
@ -835,68 +806,49 @@ void
MLASCALL
MlasRequantizeOutput(
const int32_t* Input,
size_t InputLeadingDimension,
uint8_t* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
size_t M,
size_t N,
const float* Scale,
bool PerColumnScale,
uint8_t ZeroPoint
uint8_t ZeroPoint,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN
)
/*++
Routine Description:
This routine requantizes the intermediate buffer to the output buffer
optionally adding the supplied bias.
Arguments:
Input - Supplies the input matrix.
Output - Supplies the output matrix.
Bias - Supplies the optional bias vector to be added to the input buffer
before requantization.
Buffer - Supplies the output matrix.
M - Supplies the number of elements of the bias vector and the number of
rows in the output matrix.
N - Supplies the number of columns of the output matrix.
Scale - Supplies the quantization scale.
PerColumnScale - Supplies true if the quantization scale has per-column
values, else false if a single quantization scale applies to the
entire matrix.
ZeroPoint - Supplies the quantization zero point value.
Return Value:
None.
--*/
{
const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale;
const float MinimumValue = float(0 - ZeroPoint);
const float MaximumValue = float(255 - ZeroPoint);
if (nullptr != Bias) {
Bias += StartN;
}
if (PerColumnScale) {
Scale += StartN;
}
Input += StartM * InputLeadingDimension + StartN;
Output += StartM * OutputLeadingDimension + StartN;
//
// Step through each row of the output matrix.
//
while (M-- > 0) {
while (CountM-- > 0) {
const int32_t* bias = Bias;
const float* scale = Scale;
size_t n = N;
size_t n = CountN;
auto* RowInput = Input;
auto* RowOutput = Output;
while (n > 0) {
int32_t IntegerValue = *Input++;
int32_t IntegerValue = *RowInput++;
if (bias != nullptr) {
IntegerValue += *bias++;
@ -920,10 +872,14 @@ Return Value:
IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) -
MLAS_ROUNDING_BIAS_MAGIC_BITS;
*Output++ = uint8_t(IntegerValue + ZeroPoint);
*RowOutput++ = uint8_t(IntegerValue + ZeroPoint);
n -= 1;
}
// Next Row
Input += InputLeadingDimension;
Output += OutputLeadingDimension;
}
}

View file

@ -78,47 +78,52 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
output_scales[i] = (a_scale_data * b_scale_data[i] / y_scale_data);
}
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(sizeof(int32_t)) *
static_cast<size_t>(helper.M()) * static_cast<size_t>(helper.N()));
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
const size_t num_gemms = helper.OutputOffsets().size();
MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(helper.M());
gemm_shape.N = static_cast<size_t>(helper.N());
gemm_shape.K = static_cast<size_t>(helper.K());
gemm_shape.BIsSigned = b_is_signed;
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
gemm_params.lda = gemm_shape.K;
gemm_params.ZeroPointA = *a_offset->template Data<uint8_t>();
gemm_params.ldb = gemm_shape.N;
gemm_params.C = gemm_output;
gemm_params.ldc = gemm_shape.N;
gemm_params.BIsPacked = bool(packed_b_);
gemm_params.PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset);
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(gemm_shape.M) *
gemm_shape.N * sizeof(int32_t) * num_gemms);
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
std::vector<MLAS_GEMM_U8X8_DATA_PARAMS> gemm_params(num_gemms);
std::vector<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR> requant_procs;
requant_procs.reserve(num_gemms);
auto b_zp_data = static_cast<const uint8_t*>(b_offset->DataRaw());
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
gemm_params.A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
gemm_params.B = b_data + helper.RightOffsets()[i];
gemm_params.ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i];
for (size_t i = 0; i < num_gemms; i++) {
gemm_params[i].A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
gemm_params[i].lda = gemm_shape.K;
gemm_params[i].ZeroPointA = *a_offset->template Data<uint8_t>();
MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool());
gemm_params[i].B = b_data + helper.RightOffsets()[i];
gemm_params[i].ldb = gemm_shape.N;
gemm_params[i].BIsPacked = bool(packed_b_);
gemm_params[i].ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i];
//TODO!! consider making this a post processor, so that we can parallize this loop
MlasRequantizeOutput(gemm_output,
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
nullptr,
static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
output_scales.data() + helper.RightScaleOffsets()[i],
output_scales.size() > 1,
*y_offset->template Data<uint8_t>());
gemm_params[i].C = gemm_output + (gemm_shape.M * gemm_shape.N * i);
gemm_params[i].ldc = gemm_shape.N;
gemm_params[i].PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset);
requant_procs.emplace_back(y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
nullptr,
output_scales.data() + helper.RightScaleOffsets()[i],
output_scales.size() > 1,
*y_offset->template Data<uint8_t>());
gemm_params[i].OutputProcessor = &(requant_procs[i]);
}
MlasGemmBatch(gemm_shape, gemm_params.data(), num_gemms, ctx->GetOperatorThreadPool());
return Status::OK();
}

View file

@ -590,13 +590,16 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
MlasRequantizeOutput(
worker_gemm_output,
worker_requantize_output,
Bdata,
static_cast<size_t>(output_count),
static_cast<size_t>(M),
worker_requantize_output,
static_cast<size_t>(M),
Bdata,
output_scales.data(),
output_scales.size() > 1,
Y_zero_point_value);
Y_zero_point_value,
0,0,
static_cast<size_t>(output_count),
static_cast<size_t>(M));
};
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, thread_count, conv_worker);