Implement Scale function for quant gemm (#5632)

* Implement a Scale function for quantization

Quantized GEMM is always followed by scaling (PerTensor or PerColumn), and the scaled result often needs to be accumulated into an existing matrix. This PR implements a post-processor that scales the quantized GEMM result and optionally accumulates it into another matrix.
This commit is contained in:
Yufeng Li 2020-11-10 23:34:38 -08:00 committed by GitHub
parent cca8cd849a
commit 2ba637c558
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 483 additions and 387 deletions

View file

@ -20,6 +20,7 @@ set(mlas_common_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/quantize.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/qladd.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/qlmul.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/qpostprocessor.cpp
)
if(MSVC)

View file

@ -206,21 +206,26 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
if (packed_weights_) {
const auto* packed_weight =
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
head_size,
&dequant_scale,
bias_data + weights_offset);
MlasGemm(
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_data + input_offset, // A
hidden_size, // lda = NH
input_zero_point, // input zero point
packed_weight, // B
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
qkv_dest + qkv_offset, // C
head_size, // ldc
&dequant_scale, // output scale
bias_data + weights_offset, // bias
nullptr); // use single-thread
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_data + input_offset, // A
hidden_size, // lda = NH
input_zero_point, // input zero point
packed_weight, // B
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
head_size, // ldc
nullptr, // use single-thread
&scale_bias_processor); // output processor
continue;
}
#endif

View file

@ -52,6 +52,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_b_) {
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
&multiplier,
bias_data);
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
@ -61,11 +65,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
packed_b_.get(),
b_zero_point,
b_is_signed_,
y_data + helper.OutputOffsets()[i],
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
static_cast<size_t>(helper.N()),
&multiplier,
bias_data,
thread_pool);
thread_pool,
&scale_bias_processor);
continue;
}
#endif

View file

@ -207,6 +207,80 @@ MlasGemm(
MLAS_THREADPOOL* ThreadPool
);
//
// Selects how many scale values accompany a quantized GEMM result.
//
enum class MLAS_QUANTIZATION_GRANULARITY {
    PerMatrix,  // a single scale value applies to every element
    PerColumn,  // a separate scale value is supplied for each column
};
//
// Selects what an output processor does with the values already present in
// its output buffer.
//
enum class MLAS_QGEMM_OUTPUT_MODE {
    ZeroMode,        // overwrite the output buffer
    AccumulateMode,  // accumulate to the output buffer
};
//
// Abstract interface for a post processor that the quantized GEMM routines
// invoke on a range of the int32_t result matrix.
//
class MLAS_QGEMM_OUTPUT_PROCESSOR {
public:
    //
    // The interface is used polymorphically, so the destructor is virtual to
    // keep destruction of a derived processor through a pointer to this
    // interface well defined.
    //
    virtual
    ~MLAS_QGEMM_OUTPUT_PROCESSOR() = default;

    //
    // Processes a rectangular range of the result matrix.
    //
    virtual
    void
    Process(
        const int32_t*, // Supplies the address of matrix to process
        size_t,         // Supplies the start row index of matrix
        size_t,         // Supplies the start col index of matrix
        size_t,         // Supplies the element count per row to process
        size_t,         // Supplies the element count per col to process
        size_t          // Supplies the leading dimension of matrix
        ) const = 0;
};
//
// Output processor that converts the int32_t GEMM result to float by applying
// a scale and an optional bias, then stores or accumulates the values into a
// separate float output matrix.
//
class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR {
public:
    //
    // Output - Supplies the address of the float output matrix.
    //
    // LeadingDimensionOutput - Supplies the leading dimension of the output
    //     matrix.
    //
    // Scale - Supplies the scale value(s): a single value for PerMatrix, or a
    //     vector indexed by column for PerColumn.
    //
    // Bias - Supplies the optional bias vector indexed by column, else
    //     nullptr if no bias is applied.
    //
    // Mode - Supplies whether the output is overwritten or accumulated into.
    //
    // QuantGran - Supplies the granularity of the scale value(s).
    //
    // The processor only stores the supplied pointers.
    //
    MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR(
        float* Output,
        size_t LeadingDimensionOutput,
        const float* Scale,
        const float* Bias,
        MLAS_QGEMM_OUTPUT_MODE Mode = MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
        MLAS_QUANTIZATION_GRANULARITY QuantGran = MLAS_QUANTIZATION_GRANULARITY::PerMatrix
        )
        : Output_(Output),
          LeadingDimensionOutput_(LeadingDimensionOutput),
          Scale_(Scale),
          Bias_(Bias),
          OutputMode_(Mode),
          QuantGran_(QuantGran)
    {
    }

    //
    // Dispatches to the ProcessImpl specialization that matches the stored
    // bias pointer, output mode and quantization granularity.
    //
    void
    Process(
        const int32_t* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
        ) const override;

private:
    //
    // Worker routine; the template arguments turn the bias, mode and
    // granularity choices into compile time constants.
    //
    template<bool HasBias, MLAS_QGEMM_OUTPUT_MODE Mode, MLAS_QUANTIZATION_GRANULARITY QuantGran>
    inline
    void
    ProcessImpl(
        const int32_t* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
        ) const;

    float* Output_;
    size_t LeadingDimensionOutput_;
    const float* Scale_;
    const float* Bias_;
    MLAS_QGEMM_OUTPUT_MODE OutputMode_;
    MLAS_QUANTIZATION_GRANULARITY QuantGran_;
};
void
MLASCALL
MlasGemm(
@ -222,27 +296,8 @@ MlasGemm(
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
bool BIsSigned,
float* C,
size_t ldc,
const float* Scale,
const float* Bias,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
);
void
@ -259,26 +314,8 @@ MlasGemm(
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const void* PackedB,
uint8_t offb,
bool BIsSigned,
float* C,
size_t ldc,
const float* Scale,
const float* Bias,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
);
//

View file

@ -40,13 +40,11 @@ struct MLAS_GEMM_U8X8_WORK_BLOCK {
size_t ldb;
int32_t* C;
size_t ldc;
const float* Scale;
const float* BiasFloat;
uint8_t offa;
uint8_t offb;
bool BIsPacked;
bool BIsSigned;
bool CIsFloat;
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor;
};
//
@ -84,123 +82,6 @@ MlasGemmU8X8ScaleSumBuffer(
return MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale);
}
void
MlasGemmU8X8OutputFloat(
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock,
int32_t* C,
size_t StartN,
size_t CountM,
size_t CountN
)
/*++
Routine Description:
This routine converts the output matrix to a floating point format using
the supplied scale and bias parameters.
Arguments:
WorkBlock - Supplies the structure containing the GEMM parameters.
C - Supplies the address of matrix C.
StartN - Supplies the starting column offset relative to the base of the
work block. This is used to offset into column vectors accessed via the
work block.
CountM - Supplies the number of rows of the output matrix to process.
CountN - Supplies the number of columns of the output matrix to process.
Return Value:
None.
--*/
{
const size_t ldc = WorkBlock->ldc;
const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(WorkBlock->Scale);
#if !defined(MLAS_SSE2_INTRINSICS)
const float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector);
#endif
//
// Check if the optional bias vector was supplied.
//
const float* BiasFloat = WorkBlock->BiasFloat;
if (BiasFloat != nullptr) {
BiasFloat += WorkBlock->RangeStartN + StartN;
while (CountM-- > 0) {
const float* bias = BiasFloat;
int32_t* c = C;
size_t n = CountN;
while (n >= 4) {
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias));
MlasStoreFloat32x4(reinterpret_cast<float*>(c), FloatVector);
bias += 4;
c += 4;
n -= 4;
}
for (size_t offset = 0; offset < n; offset++) {
#if defined(MLAS_SSE2_INTRINSICS)
__m128 FloatVector = _mm_set_ss(float(c[offset]));
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset]));
_mm_store_ss(reinterpret_cast<float*>(&c[offset]), FloatVector);
#else
*reinterpret_cast<float*>(&c[offset]) = float(c[offset]) * ScaleValue + bias[offset];
#endif
}
C += ldc;
}
} else {
while (CountM-- > 0) {
int32_t* c = C;
size_t n = CountN;
while (n >= 4) {
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
MlasStoreFloat32x4(reinterpret_cast<float*>(c), FloatVector);
c += 4;
n -= 4;
}
for (size_t offset = 0; offset < n; offset++) {
#if defined(MLAS_SSE2_INTRINSICS)
__m128 FloatVector = _mm_set_ss((float)c[offset]);
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
_mm_store_ss(reinterpret_cast<float*>(&c[offset]), FloatVector);
#else
*reinterpret_cast<float*>(&c[offset]) = float(c[offset]) * ScaleValue;
#endif
}
C += ldc;
}
}
}
template<typename KernelType>
void
MLASCALL
@ -251,7 +132,7 @@ Return Value:
// Try to use a GEMV kernel if supported by this kernel type.
//
if ((M == 1) && (offa == 0) && (offb == 0) && !WorkBlock->CIsFloat) {
if ((M == 1) && (offa == 0) && (offb == 0) && WorkBlock->OutputProcessor == nullptr) {
if (KernelType::TryGemvKernel(A, B, ldb, C, K, N, WorkBlock->BIsSigned)) {
return;
}
@ -346,8 +227,13 @@ Return Value:
RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
DepthValue, ZeroMode);
if (PostProcess && WorkBlock->CIsFloat) {
MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN);
if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
WorkBlock->OutputProcessor->Process(WorkBlock->C,
WorkBlock->RangeStartM + m + CountM - RowsRemaining,
WorkBlock->RangeStartN + n,
RowsHandled,
CountN,
WorkBlock->ldc);
}
c += ldc * RowsHandled;
@ -508,8 +394,14 @@ Return Value:
RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
DepthValue, ZeroMode);
if (PostProcess && WorkBlock->CIsFloat) {
MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN);
if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
WorkBlock->OutputProcessor->Process(
WorkBlock->C,
WorkBlock->RangeStartM + m + CountM - RowsRemaining,
WorkBlock->RangeStartN + n,
RowsHandled,
CountN,
WorkBlock->ldc);
}
c += ldc * RowsHandled;
@ -2187,7 +2079,8 @@ MlasGemm(
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
)
/*++
@ -2227,6 +2120,8 @@ Arguments:
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
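OutputProcessor - Supplies the optional post processor to invoke on the
results stored to matrix C, else nullptr if no post processing is required.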
Return Value:
None.
@ -2250,6 +2145,7 @@ Return Value:
WorkBlock.ldb = ldb;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.OutputProcessor = OutputProcessor;
WorkBlock.offa = offa;
WorkBlock.offb = offb;
WorkBlock.BIsSigned = BIsSigned;
@ -2261,107 +2157,6 @@ Return Value:
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
}
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
bool BIsSigned,
float* C,
size_t ldc,
const float* Scale,
const float* Bias,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the quantized integer matrix/matrix multiply
operation (QGEMM).
Arguments:
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
offa - Supplies the zero point offset of matrix A.
B - Supplies the address of matrix B.
ldb - Supplies the first dimension of matrix B.
offb - Supplies the zero point offset of matrix B.
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
B is unsigned data.
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
Scale - Supplies the scale multiplier to apply to each element of matrix C.
Used to scale the integer output of the QGEMM back to a floating point
number.
Bias - Supplies the bias vector to apply to element of matrix C. The vector
is of length N.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
//
// Capture the GEMM parameters to the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = B;
WorkBlock.ldb = ldb;
WorkBlock.C = (int32_t*)C;
WorkBlock.ldc = ldc;
WorkBlock.Scale = Scale;
WorkBlock.BiasFloat = Bias;
WorkBlock.offa = offa;
WorkBlock.offb = offb;
WorkBlock.BIsSigned = BIsSigned;
WorkBlock.CIsFloat = true;
//
// Schedule the operation across a set of worker threads.
//
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
}
#endif // MLAS_SUPPORTS_GEMM_U8X8
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
@ -2380,7 +2175,8 @@ MlasGemm(
bool BIsSigned,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
)
/*++
@ -2418,100 +2214,7 @@ Arguments:
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
//
// Capture the GEMM parameters to the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = PackedB;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.offa = offa;
WorkBlock.offb = offb;
WorkBlock.BIsPacked = true;
WorkBlock.BIsSigned = BIsSigned;
//
// Schedule the operation across a set of worker threads.
//
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
}
void
MLASCALL
MlasGemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const void* PackedB,
uint8_t offb,
bool BIsSigned,
float* C,
size_t ldc,
const float* Scale,
const float* Bias,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the quantized integer matrix/matrix multiply
operation (QGEMM).
Arguments:
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
offa - Supplies the zero point offset of matrix A.
PackedB - Supplies the address of packed matrix B.
offb - Supplies the zero point offset of matrix B.
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
B is unsigned data.
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
Scale - Supplies the scale multiplier to apply to each element of matrix C.
Used to scale the integer output of the QGEMM back to a floating point
number.
Bias - Supplies the bias vector to apply to element of matrix C. The vector
is of length N.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
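OutputProcessor - Supplies the optional post processor to invoke on the
results stored to matrix C, else nullptr if no post processing is required.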
Return Value:
@ -2535,13 +2238,11 @@ Return Value:
WorkBlock.B = PackedB;
WorkBlock.C = (int32_t*)C;
WorkBlock.ldc = ldc;
WorkBlock.Scale = Scale;
WorkBlock.BiasFloat = Bias;
WorkBlock.OutputProcessor = OutputProcessor;
WorkBlock.offa = offa;
WorkBlock.offb = offb;
WorkBlock.BIsPacked = true;
WorkBlock.BIsSigned = BIsSigned;
WorkBlock.CIsFloat = true;
//
// Schedule the operation across a set of worker threads.

View file

@ -0,0 +1,242 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
qpostprocessor.cpp
Abstract:
This module implements the post processor for QGEMM.
--*/
#include "mlasi.h"
#ifdef MLAS_SUPPORTS_GEMM_U8X8
void MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::Process(
const int32_t* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const
{
if (Bias_) {
if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
} else {
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
}
} else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
} else {
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
}
} else {
if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
} else {
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
}
} else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
} else {
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
C,
StartM,
StartN,
CountM,
CountN,
ldc);
}
}
}
template<bool HasBias, MLAS_QGEMM_OUTPUT_MODE Mode, MLAS_QUANTIZATION_GRANULARITY QuantGran>
inline
void
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::ProcessImpl(
const int32_t* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc) const
/*++
Routine Description:
This routine converts the output matrix C to a floating point format using
the stored scale and bias parameters.
Arguments:
C - Supplies the address of matrix C.
StartM - Supplies the starting row offset relative to the matrix.
StartN - Supplies the starting column offset relative to the matrix.
CountM - Supplies the number of rows of the output matrix to process.
CountN - Supplies the number of columns of the output matrix to process.
ldc - Supplies the leading dimension of C.
Return Value:
None.
--*/
{
float* Output = Output_;
const float* Bias = Bias_;
const float* Scale = Scale_;
if (HasBias) {
Bias += StartN;
}
if(QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn){
Scale += StartN;
}
MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale_);
#if !defined(MLAS_SSE2_INTRINSICS)
float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector);
#endif
C += StartM * ldc + StartN;
Output += StartM * LeadingDimensionOutput_ + StartN;
while (CountM-- > 0) {
float* c_out = Output;
const int32_t* c = C;
const float* bias = Bias;
const float* scale = Scale;
size_t n = CountN;
while (n >= 4) {
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
ScaleVector = MlasLoadFloat32x4(scale);
scale += 4;
}
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
FloatVector = MlasMultiplyAddFloat32x4(FloatVector, ScaleVector, MlasLoadFloat32x4(c_out));
} else {
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
}
if (HasBias) {
FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias));
bias += 4;
}
MlasStoreFloat32x4(c_out, FloatVector);
c_out += 4;
c += 4;
n -= 4;
}
for (size_t offset = 0; offset < n; offset++) {
#if defined(MLAS_SSE2_INTRINSICS)
__m128 FloatVector = _mm_set_ss(float(c[offset]));
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
ScaleVector = _mm_load_ss(&scale[offset]);
}
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
FloatVector = _mm_add_ps(_mm_mul_ss(FloatVector, ScaleVector), _mm_load_ss(&c_out[offset]));
} else {
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
}
if (HasBias) {
FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset]));
}
_mm_store_ss(&c_out[offset], FloatVector);
#else
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
ScaleValue = scale[offset];
}
float result = float(c[offset]) * ScaleValue;
if (HasBias) {
result += bias[offset];
}
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
c_out[offset] += result;
} else {
c_out[offset] = result;
}
#endif
}
C += ldc;
Output += LeadingDimensionOutput_;
}
}
#endif // MLAS_SUPPORTS_GEMM_U8X8

View file

@ -85,7 +85,13 @@ void QGemm(
const float* bias,
concurrency::ThreadPool* thread_pool) {
#ifdef MLAS_SUPPORTS_GEMM_U8X8
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, result_data, ldc, result_scale, bias, thread_pool);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(result_data, ldc, result_scale, bias);
MlasGemm(M, N, K,
lhs_data, lda, lhs_offset,
rhs_data, ldb, rhs_offset, rhs_signed,
reinterpret_cast<int32_t*>(result_data), ldc,
thread_pool,
&scale_bias_processor);
#else
QGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, reinterpret_cast<int32_t*>(result_data), ldc, thread_pool);
for (int m = 0; m < M; m++) {

View file

@ -573,7 +573,13 @@ protected:
const float* Bias
)
{
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
MlasGemm(M, N, K,
A, lda, offa,
B, ldb, offb, BIsSigned,
reinterpret_cast<int32_t*>(C), ldc,
threadpool,
&scale_bias_processor);
}
};
@ -638,7 +644,13 @@ protected:
)
{
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool);
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
MlasGemm(M, N, K,
A, lda, offa,
PackedB, offb, BIsSigned,
reinterpret_cast<int32_t*>(C), ldc,
threadpool,
&scale_bias_processor);
}
private:
@ -2756,6 +2768,92 @@ public:
}
};
//
// Checks MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR (without a bias vector)
// against a scalar reference for every combination of scale granularity and
// output mode.
//
class MlasScaleOutputTest : public MlasTestBase
{
private:
    MatrixGuardBuffer<int32_t> BufferInput;
    MatrixGuardBuffer<float> BufferOutput;
    MatrixGuardBuffer<float> BufferOutputRef;
    MatrixGuardBuffer<float> BufferScale;

    //
    // Runs one M x N case.
    //
    // PerColumn - true to use one scale per column, false for a single scale.
    // AccumulateMode - true to accumulate into the pre-filled output, false
    //     to overwrite it.
    //
    void
    Test(
        size_t M,
        size_t N,
        bool PerColumn,
        bool AccumulateMode
        )
    {
        const size_t ScaleCount = PerColumn ? N : 1;

        int32_t* Input = BufferInput.GetBuffer(M * N);
        float* Output = BufferOutput.GetBuffer(M * N);
        float* OutputRef = BufferOutputRef.GetBuffer(M * N);
        float* Scale = BufferScale.GetBuffer(ScaleCount);

        // Seed from the problem size so every case is reproducible.
        std::default_random_engine generator(static_cast<unsigned>(M * N));
        std::uniform_real_distribution<float> real_distribution(-1.0f, 1.0f);
        std::uniform_int_distribution<int32_t> int_distribution(std::numeric_limits<int16_t>::min(),
                                                                std::numeric_limits<int16_t>::max());

        // Output and OutputRef start from the same values so that accumulate
        // mode has identical initial contents on both sides.
        for (size_t s = 0; s < M * N; s++) {
            Input[s] = int_distribution(generator);
            Output[s] = OutputRef[s] = real_distribution(generator);
        }

        for (size_t s = 0; s < ScaleCount; s++) {
            Scale[s] = real_distribution(generator);
        }

        // Compute Reference Value
        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                float current_scale = PerColumn ? Scale[n] : Scale[0];
                if (AccumulateMode) {
                    OutputRef[m * N + n] += Input[m * N + n] * current_scale;
                } else {
                    OutputRef[m * N + n] = Input[m * N + n] * current_scale;
                }
            }
        }

        // Compute Output with MLAS
        MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR OutputProcessor(Output, N, Scale, nullptr,
            AccumulateMode ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
            PerColumn ? MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix);
        OutputProcessor.Process(Input, 0, 0, M, N, N);

        //
        // The products reach magnitudes of about 32768, where adjacent float
        // values are much further apart than 1e-6. The processor may also
        // round differently from the reference above (for example if its
        // multiply-add is fused), so scale the tolerance by the magnitude of
        // the reference value instead of using a fixed absolute bound.
        //
        constexpr float epsilon = 1e-6f;

        for (size_t n = 0; n < M * N; n++) {
            const float magnitude = std::fabs(OutputRef[n]);
            const float tolerance = (magnitude > 1.0f) ? epsilon * magnitude : epsilon;
            float diff = std::fabs(Output[n] - OutputRef[n]);
            if (diff > tolerance) {
                printf("MlasScaleOutputTest: Output[%zu][%zu]:%.8f, OutputRef[%zu][%zu]:%.8f, for case M=%zu, N=%zu\n",
                       n / N, n % N, Output[n], n / N, n % N, OutputRef[n], M, N);
            }
        }
    }

public:
    //
    // Covers every shape from 1x1 through 17x17.
    //
    void
    ExecuteShort(
        void
        ) override
    {
        for (size_t m = 1; m < 18; m++) {
            for (size_t n = 1; n < 18; n++) {
                Test(m, n, true, true);
                Test(m, n, true, false);
                Test(m, n, false, true);
                Test(m, n, false, false);
            }
        }
    }
};
void
RunThreadedTests(
void
@ -2870,6 +2968,9 @@ main(
onnxruntime::make_unique<MlasQLinearBinaryOpTest>(
[] (float a, float b) { return a * b; }, "*", MlasQLinearMul<int8_t>, MlasQLinearMul<uint8_t>)->ExecuteShort();
printf("MlasScaleOutput tests.\n");
onnxruntime::make_unique<MlasScaleOutputTest>()->ExecuteShort();
printf("Done.\n");
return 0;