mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Make requantize a qgemm post processor (#7850)
Description: Change the requantize interface so that it can be processed block by block. This enables us to make requantize a post processor of QGEMM. Motivation and Context: Previous changes showed that we can improve performance by parallelizing batch GEMM. Unfortunately, we could not parallelize the batch GEMM in quantize_linear_matmul because of the requantize operation at the end of each GEMM. By changing requantize to be a QGEMM post processor, we can now parallelize the batch operation. Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
parent
ccdedf1b2e
commit
8140e3fde5
5 changed files with 227 additions and 194 deletions
|
|
@ -963,19 +963,83 @@ MlasQuantizeLinear(
|
|||
OutputType ZeroPoint
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Requantize a block of the intermediate buffer to the output buffer,
|
||||
* optionally adding the supplied bias
|
||||
*
|
||||
* @param Input Input matrix
|
||||
* @param InputLeadingDimension Input matrix leading dimension
|
||||
* @param Output Output matrix
|
||||
* @param OutputLeadingDimension Output matrix leading dimension
|
||||
* @param Bias Optional bias vector, to be added
|
||||
*                               to the input before quantization
|
||||
* @param Scale Quantization scale
|
||||
* @param PerColumnScale true if scale is per-column
|
||||
* @param ZeroPoint quantization zero point value
|
||||
* @param StartM                  Index of the first row of the block to requantize
|
||||
* @param StartN                  Index of the first column of the block to requantize
|
||||
* @param CountM                  Number of rows in the block to requantize
|
||||
* @param CountN                  Number of columns in the block to requantize
|
||||
* @return
|
||||
*/
|
||||
void
|
||||
MLASCALL
|
||||
MlasRequantizeOutput(
|
||||
const int32_t* Input,
|
||||
size_t InputLeadingDimension,
|
||||
uint8_t* Output,
|
||||
size_t OutputLeadingDimension,
|
||||
const int32_t* Bias,
|
||||
size_t M,
|
||||
size_t N,
|
||||
const float* Scale,
|
||||
bool PerColumnScale,
|
||||
uint8_t ZeroPoint
|
||||
uint8_t ZeroPoint,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN
|
||||
);
|
||||
|
||||
/**
 * @brief QGEMM output processor that requantizes a block of the int32
 *        GEMM result into a uint8_t output buffer by forwarding to
 *        MlasRequantizeOutput.
 *
 * The object only stores the pointers and values handed to its constructor;
 * it does not allocate or free any of the referenced buffers.
 */
class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
|
||||
{
|
||||
public:
|
||||
/**
 * @brief Capture the requantization parameters used by every Process() call.
 *
 * @param Output                  Destination uint8_t matrix
 * @param OutputLeadingDimension  Leading dimension of the destination matrix
 * @param Bias                    Bias vector forwarded to MlasRequantizeOutput
 *                                (callers in this change pass nullptr for "no bias")
 * @param Scale                   Quantization scale(s)
 * @param PerColumnScale          true if Scale holds one value per column
 * @param ZeroPoint               Quantization zero point value
 */
MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR(
|
||||
uint8_t* Output,
|
||||
size_t OutputLeadingDimension,
|
||||
const int32_t* Bias,
|
||||
const float* Scale,
|
||||
bool PerColumnScale,
|
||||
uint8_t ZeroPoint)
|
||||
: Output_(Output),
|
||||
OutputLeadingDimension_(OutputLeadingDimension),
|
||||
Bias_(Bias),
|
||||
Scale_(Scale),
|
||||
PerColumnScale_(PerColumnScale),
|
||||
ZeroPoint_(ZeroPoint)
|
||||
{
|
||||
}
|
||||
|
||||
/**
 * @brief Requantize one block of the intermediate int32 matrix C.
 *
 * @param C       Intermediate int32 matrix, passed as the requantize input
 * @param StartM  First row of the block
 * @param StartN  First column of the block
 * @param CountM  Number of rows in the block
 * @param CountN  Number of columns in the block
 * @param ldc     Leading dimension of C, passed as the input leading dimension
 */
void Process(const int32_t* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc) const override
|
||||
{
|
||||
// Forward the block coordinates unchanged together with the stored parameters.
MlasRequantizeOutput(C, ldc, Output_, OutputLeadingDimension_, Bias_, Scale_,
|
||||
PerColumnScale_, ZeroPoint_, StartM, StartN, CountM, CountN);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
// Requantization parameters captured at construction; see the constructor.
uint8_t* Output_;
|
||||
size_t OutputLeadingDimension_;
|
||||
const int32_t* Bias_;
|
||||
const float* Scale_;
|
||||
bool PerColumnScale_;
|
||||
uint8_t ZeroPoint_;
|
||||
};
|
||||
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasFindMinMaxElement(
|
||||
|
|
|
|||
|
|
@ -121,7 +121,9 @@ MlasQLinearGlobalAveragePoolNchw(
|
|||
int32x2_t vacc = vadd_s32(vget_high_s32(vacc_lo), vget_low_s32(vacc_lo));
|
||||
*sum_buffer++ = vget_lane_s32(vpadd_s32(vacc, vacc), 0);
|
||||
}
|
||||
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast<uint8_t>(ZeroPointOutput));
|
||||
|
||||
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false,
|
||||
static_cast<uint8_t>(ZeroPointOutput), 0, 0, 1, Channels);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
|
|
@ -256,7 +258,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch(
|
|||
vst1q_s32(acc + 4, vacc_hi);
|
||||
}
|
||||
}
|
||||
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point);
|
||||
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false,
|
||||
Output_zero_point, 0, 0, 1, Channels);
|
||||
}
|
||||
|
||||
#elif defined(MLAS_SSE2_INTRINSICS)
|
||||
|
|
@ -323,7 +326,8 @@ MlasQLinearGlobalAveragePoolNchw(
|
|||
vsums = _mm_add_epi32(vsums, vshuf);
|
||||
*sum_buffer++ = _mm_cvtsi128_si32(vsums);
|
||||
}
|
||||
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &scale, false, static_cast<uint8_t>(ZeroPointOutput));
|
||||
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false,
|
||||
static_cast<uint8_t>(ZeroPointOutput), 0, 0, 1, Channels);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
|
|
@ -515,7 +519,8 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch(
|
|||
_mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi);
|
||||
}
|
||||
}
|
||||
MlasRequantizeOutput(AccumulateBuffer, Output, nullptr, 1, Channels, &Scale, false, Output_zero_point);
|
||||
MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false,
|
||||
Output_zero_point, 0, 0, 1, Channels);
|
||||
}
|
||||
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -356,65 +356,46 @@ void
|
|||
MLASCALL
|
||||
MlasRequantizeOutput(
|
||||
const int32_t* Input,
|
||||
size_t InputLeadingDimension,
|
||||
uint8_t* Output,
|
||||
size_t OutputLeadingDimension,
|
||||
const int32_t* Bias,
|
||||
size_t M,
|
||||
size_t N,
|
||||
const float* Scale,
|
||||
bool PerColumnScale,
|
||||
uint8_t ZeroPoint
|
||||
uint8_t ZeroPoint,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine requantizes the intermediate buffer to the output buffer
|
||||
optionally adding the supplied bias.
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input matrix.
|
||||
|
||||
Output - Supplies the output matrix.
|
||||
|
||||
Bias - Supplies the optional bias vector to be added to the input buffer
|
||||
before requantization.
|
||||
|
||||
Buffer - Supplies the output matrix.
|
||||
|
||||
M - Supplies the number of elements of the bias vector and the number of
|
||||
rows in the output matrix.
|
||||
|
||||
N - Supplies the number of columns of the output matrix.
|
||||
|
||||
Scale - Supplies the quantization scale.
|
||||
|
||||
PerColumnScale - Supplies true if the quantization scale has per-column
|
||||
values, else false if a single quantization scale applies to the
|
||||
entire matrix.
|
||||
|
||||
ZeroPoint - Supplies the quantization zero point value.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const __m128 PerMatrixScaleVector = PerColumnScale ? _mm_setzero_ps() : _mm_load1_ps(Scale);
|
||||
const __m128 MinimumValueVector = _mm_set1_ps(float(0 - ZeroPoint));
|
||||
const __m128 MaximumValueVector = _mm_set1_ps(float(255 - ZeroPoint));
|
||||
const __m128i ZeroPointVector = _mm_set1_epi32(ZeroPoint);
|
||||
|
||||
if (nullptr != Bias) {
|
||||
Bias += StartN;
|
||||
}
|
||||
if (PerColumnScale) {
|
||||
Scale += StartN;
|
||||
}
|
||||
|
||||
Input += StartM * InputLeadingDimension + StartN;
|
||||
Output += StartM * OutputLeadingDimension + StartN;
|
||||
|
||||
//
|
||||
// Step through each row of the output matrix.
|
||||
//
|
||||
|
||||
while (M-- > 0) {
|
||||
while (CountM-- > 0) {
|
||||
|
||||
const int32_t* bias = Bias;
|
||||
const float* scale = PerColumnScale ? Scale : nullptr;
|
||||
size_t n = N;
|
||||
size_t n = CountN;
|
||||
|
||||
auto* RowInput = Input;
|
||||
auto* RowOutput = Output;
|
||||
|
||||
//
|
||||
// Process 16 columns of the matrices at a time.
|
||||
|
|
@ -426,11 +407,11 @@ Return Value:
|
|||
// Load the input data and optionally add the per-column bias.
|
||||
//
|
||||
|
||||
__m128i IntegerVector0 = _mm_loadu_si128((const __m128i *)&Input[0]);
|
||||
__m128i IntegerVector1 = _mm_loadu_si128((const __m128i *)&Input[4]);
|
||||
__m128i IntegerVector2 = _mm_loadu_si128((const __m128i *)&Input[8]);
|
||||
__m128i IntegerVector3 = _mm_loadu_si128((const __m128i *)&Input[12]);
|
||||
Input += 16;
|
||||
__m128i IntegerVector0 = _mm_loadu_si128((const __m128i*)&RowInput[0]);
|
||||
__m128i IntegerVector1 = _mm_loadu_si128((const __m128i*)&RowInput[4]);
|
||||
__m128i IntegerVector2 = _mm_loadu_si128((const __m128i*)&RowInput[8]);
|
||||
__m128i IntegerVector3 = _mm_loadu_si128((const __m128i*)&RowInput[12]);
|
||||
RowInput += 16;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerVector0 = _mm_add_epi32(IntegerVector0, _mm_loadu_si128((const __m128i *)&bias[0]));
|
||||
|
|
@ -491,8 +472,8 @@ Return Value:
|
|||
|
||||
__m128i ByteVector = _mm_packus_epi16(WordVector0, WordVector1);
|
||||
|
||||
_mm_storeu_si128((__m128i*)Output, ByteVector);
|
||||
Output += 16;
|
||||
_mm_storeu_si128((__m128i*)RowOutput, ByteVector);
|
||||
RowOutput += 16;
|
||||
|
||||
n -= 16;
|
||||
}
|
||||
|
|
@ -511,8 +492,8 @@ Return Value:
|
|||
|
||||
if (n >= 4) {
|
||||
|
||||
IntegerVector = _mm_loadu_si128((const __m128i*)&Input[0]);
|
||||
Input += 4;
|
||||
IntegerVector = _mm_loadu_si128((const __m128i*)&RowInput[0]);
|
||||
RowInput += 4;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerVector = _mm_add_epi32(IntegerVector, _mm_loadu_si128((const __m128i*)&bias[0]));
|
||||
|
|
@ -521,7 +502,7 @@ Return Value:
|
|||
|
||||
} else {
|
||||
|
||||
int32_t IntegerValue = *Input++;
|
||||
int32_t IntegerValue = *RowInput++;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerValue += *bias++;
|
||||
|
|
@ -567,19 +548,23 @@ Return Value:
|
|||
|
||||
if (n >= 4) {
|
||||
|
||||
*reinterpret_cast<uint32_t*>(Output) = OutputValue;
|
||||
Output += 4;
|
||||
*reinterpret_cast<uint32_t*>(RowOutput) = OutputValue;
|
||||
RowOutput += 4;
|
||||
|
||||
n -= 4;
|
||||
|
||||
} else {
|
||||
|
||||
*Output = uint8_t(OutputValue);
|
||||
Output += 1;
|
||||
*RowOutput = uint8_t(OutputValue);
|
||||
RowOutput += 1;
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Next Row
|
||||
Input += InputLeadingDimension;
|
||||
Output += OutputLeadingDimension;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -589,63 +574,44 @@ void
|
|||
MLASCALL
|
||||
MlasRequantizeOutput(
|
||||
const int32_t* Input,
|
||||
size_t InputLeadingDimension,
|
||||
uint8_t* Output,
|
||||
size_t OutputLeadingDimension,
|
||||
const int32_t* Bias,
|
||||
size_t M,
|
||||
size_t N,
|
||||
const float* Scale,
|
||||
bool PerColumnScale,
|
||||
uint8_t ZeroPoint
|
||||
uint8_t ZeroPoint,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine requantizes the intermediate buffer to the output buffer
|
||||
optionally adding the supplied bias.
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input matrix.
|
||||
|
||||
Output - Supplies the output matrix.
|
||||
|
||||
Bias - Supplies the optional bias vector to be added to the input buffer
|
||||
before requantization.
|
||||
|
||||
Buffer - Supplies the output matrix.
|
||||
|
||||
M - Supplies the number of elements of the bias vector and the number of
|
||||
rows in the output matrix.
|
||||
|
||||
N - Supplies the number of columns of the output matrix.
|
||||
|
||||
Scale - Supplies the quantization scale.
|
||||
|
||||
PerColumnScale - Supplies true if the quantization scale has per-column
|
||||
values, else false if a single quantization scale applies to the
|
||||
entire matrix.
|
||||
|
||||
ZeroPoint - Supplies the quantization zero point value.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const float32x4_t PerMatrixScaleVector = PerColumnScale ? vdupq_n_f32(0) : vld1q_dup_f32(Scale);
|
||||
const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint);
|
||||
|
||||
if (nullptr != Bias) {
|
||||
Bias += StartN;
|
||||
}
|
||||
if (PerColumnScale) {
|
||||
Scale += StartN;
|
||||
}
|
||||
|
||||
Input += StartM * InputLeadingDimension + StartN;
|
||||
Output += StartM * OutputLeadingDimension + StartN;
|
||||
|
||||
//
|
||||
// Step through each row of the output matrix.
|
||||
//
|
||||
|
||||
while (M-- > 0) {
|
||||
while (CountM-- > 0) {
|
||||
|
||||
const int32_t* bias = Bias;
|
||||
const float* scale = PerColumnScale ? Scale : nullptr;
|
||||
size_t n = N;
|
||||
size_t n = CountN;
|
||||
|
||||
auto* RowInput = Input;
|
||||
auto* RowOutput = Output;
|
||||
|
||||
//
|
||||
// Process 16 columns of the matrices at a time.
|
||||
|
|
@ -659,11 +625,11 @@ Return Value:
|
|||
|
||||
int32x4x4_t IntegerVector;
|
||||
|
||||
IntegerVector.val[0] = vld1q_s32(&Input[0]);
|
||||
IntegerVector.val[1] = vld1q_s32(&Input[4]);
|
||||
IntegerVector.val[2] = vld1q_s32(&Input[8]);
|
||||
IntegerVector.val[3] = vld1q_s32(&Input[12]);
|
||||
Input += 16;
|
||||
IntegerVector.val[0] = vld1q_s32(&RowInput[0]);
|
||||
IntegerVector.val[1] = vld1q_s32(&RowInput[4]);
|
||||
IntegerVector.val[2] = vld1q_s32(&RowInput[8]);
|
||||
IntegerVector.val[3] = vld1q_s32(&RowInput[12]);
|
||||
RowInput += 16;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerVector.val[0] = vaddq_s32(IntegerVector.val[0], vld1q_s32(&bias[0]));
|
||||
|
|
@ -731,8 +697,8 @@ Return Value:
|
|||
WordVector.val[0] = vqaddq_s16(WordVector.val[0], ZeroPointVector);
|
||||
WordVector.val[1] = vqaddq_s16(WordVector.val[1], ZeroPointVector);
|
||||
|
||||
vst1q_u8(Output, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1]));
|
||||
Output += 16;
|
||||
vst1q_u8(RowOutput, vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1]));
|
||||
RowOutput += 16;
|
||||
|
||||
n -= 16;
|
||||
}
|
||||
|
|
@ -751,8 +717,8 @@ Return Value:
|
|||
|
||||
if (n >= 4) {
|
||||
|
||||
IntegerVector = vld1q_s32(&Input[0]);
|
||||
Input += 4;
|
||||
IntegerVector = vld1q_s32(&RowInput[0]);
|
||||
RowInput += 4;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerVector = vaddq_s32(IntegerVector, vld1q_s32(&bias[0]));
|
||||
|
|
@ -761,8 +727,8 @@ Return Value:
|
|||
|
||||
} else {
|
||||
|
||||
IntegerVector = vld1q_dup_s32(Input);
|
||||
Input += 1;
|
||||
IntegerVector = vld1q_dup_s32(RowInput);
|
||||
RowInput += 1;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerVector = vaddq_s32(IntegerVector, vld1q_dup_s32(bias));
|
||||
|
|
@ -813,19 +779,24 @@ Return Value:
|
|||
|
||||
if (n >= 4) {
|
||||
|
||||
vst1q_lane_u32(reinterpret_cast<uint32_t*>(Output), vreinterpretq_u32_u8(ByteVector), 0);
|
||||
Output += 4;
|
||||
vst1q_lane_u32(reinterpret_cast<uint32_t*>(RowOutput),
|
||||
vreinterpretq_u32_u8(ByteVector), 0);
|
||||
RowOutput += 4;
|
||||
|
||||
n -= 4;
|
||||
|
||||
} else {
|
||||
|
||||
vst1q_lane_u8(Output, ByteVector, 0);
|
||||
Output += 1;
|
||||
vst1q_lane_u8(RowOutput, ByteVector, 0);
|
||||
RowOutput += 1;
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Next Row
|
||||
Input += InputLeadingDimension;
|
||||
Output += OutputLeadingDimension;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -835,68 +806,49 @@ void
|
|||
MLASCALL
|
||||
MlasRequantizeOutput(
|
||||
const int32_t* Input,
|
||||
size_t InputLeadingDimension,
|
||||
uint8_t* Output,
|
||||
size_t OutputLeadingDimension,
|
||||
const int32_t* Bias,
|
||||
size_t M,
|
||||
size_t N,
|
||||
const float* Scale,
|
||||
bool PerColumnScale,
|
||||
uint8_t ZeroPoint
|
||||
uint8_t ZeroPoint,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine requantizes the intermediate buffer to the output buffer
|
||||
optionally adding the supplied bias.
|
||||
|
||||
Arguments:
|
||||
|
||||
Input - Supplies the input matrix.
|
||||
|
||||
Output - Supplies the output matrix.
|
||||
|
||||
Bias - Supplies the optional bias vector to be added to the input buffer
|
||||
before requantization.
|
||||
|
||||
Buffer - Supplies the output matrix.
|
||||
|
||||
M - Supplies the number of elements of the bias vector and the number of
|
||||
rows in the output matrix.
|
||||
|
||||
N - Supplies the number of columns of the output matrix.
|
||||
|
||||
Scale - Supplies the quantization scale.
|
||||
|
||||
PerColumnScale - Supplies true if the quantization scale has per-column
|
||||
values, else false if a single quantization scale applies to the
|
||||
entire matrix.
|
||||
|
||||
ZeroPoint - Supplies the quantization zero point value.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale;
|
||||
const float MinimumValue = float(0 - ZeroPoint);
|
||||
const float MaximumValue = float(255 - ZeroPoint);
|
||||
|
||||
if (nullptr != Bias) {
|
||||
Bias += StartN;
|
||||
}
|
||||
if (PerColumnScale) {
|
||||
Scale += StartN;
|
||||
}
|
||||
|
||||
Input += StartM * InputLeadingDimension + StartN;
|
||||
Output += StartM * OutputLeadingDimension + StartN;
|
||||
|
||||
//
|
||||
// Step through each row of the output matrix.
|
||||
//
|
||||
|
||||
while (M-- > 0) {
|
||||
while (CountM-- > 0) {
|
||||
|
||||
const int32_t* bias = Bias;
|
||||
const float* scale = Scale;
|
||||
size_t n = N;
|
||||
size_t n = CountN;
|
||||
|
||||
auto* RowInput = Input;
|
||||
auto* RowOutput = Output;
|
||||
|
||||
while (n > 0) {
|
||||
|
||||
int32_t IntegerValue = *Input++;
|
||||
int32_t IntegerValue = *RowInput++;
|
||||
|
||||
if (bias != nullptr) {
|
||||
IntegerValue += *bias++;
|
||||
|
|
@ -920,10 +872,14 @@ Return Value:
|
|||
IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) -
|
||||
MLAS_ROUNDING_BIAS_MAGIC_BITS;
|
||||
|
||||
*Output++ = uint8_t(IntegerValue + ZeroPoint);
|
||||
*RowOutput++ = uint8_t(IntegerValue + ZeroPoint);
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
|
||||
// Next Row
|
||||
Input += InputLeadingDimension;
|
||||
Output += OutputLeadingDimension;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -78,47 +78,52 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
|
|||
output_scales[i] = (a_scale_data * b_scale_data[i] / y_scale_data);
|
||||
}
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(sizeof(int32_t)) *
|
||||
static_cast<size_t>(helper.M()) * static_cast<size_t>(helper.N()));
|
||||
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
|
||||
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
|
||||
|
||||
const size_t num_gemms = helper.OutputOffsets().size();
|
||||
MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
|
||||
gemm_shape.M = static_cast<size_t>(helper.M());
|
||||
gemm_shape.N = static_cast<size_t>(helper.N());
|
||||
gemm_shape.K = static_cast<size_t>(helper.K());
|
||||
gemm_shape.BIsSigned = b_is_signed;
|
||||
|
||||
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
|
||||
gemm_params.lda = gemm_shape.K;
|
||||
gemm_params.ZeroPointA = *a_offset->template Data<uint8_t>();
|
||||
gemm_params.ldb = gemm_shape.N;
|
||||
gemm_params.C = gemm_output;
|
||||
gemm_params.ldc = gemm_shape.N;
|
||||
gemm_params.BIsPacked = bool(packed_b_);
|
||||
gemm_params.PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset);
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(gemm_shape.M) *
|
||||
gemm_shape.N * sizeof(int32_t) * num_gemms);
|
||||
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
|
||||
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
|
||||
|
||||
|
||||
std::vector<MLAS_GEMM_U8X8_DATA_PARAMS> gemm_params(num_gemms);
|
||||
std::vector<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR> requant_procs;
|
||||
requant_procs.reserve(num_gemms);
|
||||
|
||||
auto b_zp_data = static_cast<const uint8_t*>(b_offset->DataRaw());
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
gemm_params.A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
|
||||
gemm_params.B = b_data + helper.RightOffsets()[i];
|
||||
gemm_params.ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i];
|
||||
for (size_t i = 0; i < num_gemms; i++) {
|
||||
gemm_params[i].A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
|
||||
gemm_params[i].lda = gemm_shape.K;
|
||||
gemm_params[i].ZeroPointA = *a_offset->template Data<uint8_t>();
|
||||
|
||||
MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool());
|
||||
gemm_params[i].B = b_data + helper.RightOffsets()[i];
|
||||
gemm_params[i].ldb = gemm_shape.N;
|
||||
gemm_params[i].BIsPacked = bool(packed_b_);
|
||||
gemm_params[i].ZeroPointB = b_zp_data + helper.RightZeroPointOffsets()[i];
|
||||
|
||||
//TODO!! consider making this a post processor, so that we can parallize this loop
|
||||
MlasRequantizeOutput(gemm_output,
|
||||
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
|
||||
nullptr,
|
||||
static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
output_scales.data() + helper.RightScaleOffsets()[i],
|
||||
output_scales.size() > 1,
|
||||
*y_offset->template Data<uint8_t>());
|
||||
gemm_params[i].C = gemm_output + (gemm_shape.M * gemm_shape.N * i);
|
||||
gemm_params[i].ldc = gemm_shape.N;
|
||||
|
||||
gemm_params[i].PerColumnZeroPoints = !IsScalarOr1ElementVector(b_offset);
|
||||
|
||||
requant_procs.emplace_back(y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
nullptr,
|
||||
output_scales.data() + helper.RightScaleOffsets()[i],
|
||||
output_scales.size() > 1,
|
||||
*y_offset->template Data<uint8_t>());
|
||||
gemm_params[i].OutputProcessor = &(requant_procs[i]);
|
||||
}
|
||||
|
||||
MlasGemmBatch(gemm_shape, gemm_params.data(), num_gemms, ctx->GetOperatorThreadPool());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -590,13 +590,16 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
|
||||
MlasRequantizeOutput(
|
||||
worker_gemm_output,
|
||||
worker_requantize_output,
|
||||
Bdata,
|
||||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(M),
|
||||
worker_requantize_output,
|
||||
static_cast<size_t>(M),
|
||||
Bdata,
|
||||
output_scales.data(),
|
||||
output_scales.size() > 1,
|
||||
Y_zero_point_value);
|
||||
Y_zero_point_value,
|
||||
0,0,
|
||||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(M));
|
||||
};
|
||||
|
||||
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, thread_count, conv_worker);
|
||||
|
|
|
|||
Loading…
Reference in a new issue