mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Implement Scale function for quant gemm (#5632)
* Implement a Scale function for quantization. A quantized GEMM is always followed by scaling (PerTensor or PerColumn), and its result often needs to be accumulated into an existing matrix. This PR implements a post-processor that scales the quantized GEMM result and optionally accumulates it into another matrix.
This commit is contained in:
parent
cca8cd849a
commit
2ba637c558
8 changed files with 483 additions and 387 deletions
|
|
@ -20,6 +20,7 @@ set(mlas_common_srcs
|
|||
${ONNXRUNTIME_ROOT}/core/mlas/lib/quantize.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/qladd.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/qlmul.cpp
|
||||
${ONNXRUNTIME_ROOT}/core/mlas/lib/qpostprocessor.cpp
|
||||
)
|
||||
|
||||
if(MSVC)
|
||||
|
|
|
|||
|
|
@ -206,21 +206,26 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
if (packed_weights_) {
|
||||
const auto* packed_weight =
|
||||
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
|
||||
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
|
||||
head_size,
|
||||
&dequant_scale,
|
||||
bias_data + weights_offset);
|
||||
MlasGemm(
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
input_zero_point, // input zero point
|
||||
packed_weight, // B
|
||||
weight_zero_point, // weight zero point
|
||||
weights_is_signed, // weight data type
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
&dequant_scale, // output scale
|
||||
bias_data + weights_offset, // bias
|
||||
nullptr); // use single-thread
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
input_zero_point, // input zero point
|
||||
packed_weight, // B
|
||||
weight_zero_point, // weight zero point
|
||||
weights_is_signed, // weight data type
|
||||
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
|
||||
head_size, // ldc
|
||||
nullptr, // use single-thread
|
||||
&scale_bias_processor); // output processor
|
||||
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -52,6 +52,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
|
|||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (packed_b_) {
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(y_data + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
&multiplier,
|
||||
bias_data);
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
|
|
@ -61,11 +65,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
|
|||
packed_b_.get(),
|
||||
b_zero_point,
|
||||
b_is_signed_,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
|
||||
static_cast<size_t>(helper.N()),
|
||||
&multiplier,
|
||||
bias_data,
|
||||
thread_pool);
|
||||
thread_pool,
|
||||
&scale_bias_processor);
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -207,6 +207,80 @@ MlasGemm(
|
|||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
//
// Selects how the scale supplied to the scale/bias output processor is
// applied to the quantized GEMM result.
//
enum class MLAS_QUANTIZATION_GRANULARITY {
|
||||
PerMatrix, // a single scale value is applied to every element
|
||||
PerColumn, // the scale is a vector indexed by output column
|
||||
};
|
||||
|
||||
//
// Selects whether the scale/bias output processor replaces the contents of
// its output buffer or adds to them.
//
enum class MLAS_QGEMM_OUTPUT_MODE {
|
||||
ZeroMode, // overwrite the output buffer with the processed result
|
||||
AccumulateMode, // add the processed result to the existing output buffer contents
|
||||
};
|
||||
|
||||
//
// Abstract interface for a post-processing step that is run over a tile of
// the int32_t result matrix produced by a quantized GEMM.
//
// Implementations override Process(); the GEMM driver invokes it through a
// pointer to this base class.
//
class MLAS_QGEMM_OUTPUT_PROCESSOR {
public:
    //
    // This is a polymorphic base class, so the destructor is virtual: a
    // derived processor may then be destroyed safely through a pointer to
    // the base.
    //
    virtual
    ~MLAS_QGEMM_OUTPUT_PROCESSOR() = default;

    virtual
    void
    Process(
        const int32_t*, // Supplies the address of the matrix to process
        size_t,         // Supplies the starting row index within the matrix
        size_t,         // Supplies the starting column index within the matrix
        size_t,         // Supplies the number of rows to process
        size_t,         // Supplies the number of columns to process
        size_t          // Supplies the leading dimension of the matrix
        ) const = 0;
};
|
||||
|
||||
class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR {
|
||||
public:
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR(
|
||||
float* Output,
|
||||
size_t LeadingDimensionOutput,
|
||||
const float* Scale,
|
||||
const float* Bias,
|
||||
MLAS_QGEMM_OUTPUT_MODE Mode = MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
|
||||
MLAS_QUANTIZATION_GRANULARITY QuantGran = MLAS_QUANTIZATION_GRANULARITY::PerMatrix) :
|
||||
Output_(Output),
|
||||
LeadingDimensionOutput_(LeadingDimensionOutput),
|
||||
Scale_(Scale),
|
||||
Bias_(Bias),
|
||||
OutputMode_(Mode),
|
||||
QuantGran_(QuantGran)
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
Process(
|
||||
const int32_t* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc
|
||||
) const override;
|
||||
|
||||
private:
|
||||
template<bool HasBias, MLAS_QGEMM_OUTPUT_MODE Mode, MLAS_QUANTIZATION_GRANULARITY QuantGran>
|
||||
inline
|
||||
void
|
||||
ProcessImpl(
|
||||
const int32_t* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc
|
||||
) const;
|
||||
|
||||
private:
|
||||
float* Output_;
|
||||
size_t LeadingDimensionOutput_;
|
||||
const float* Scale_;
|
||||
const float* Bias_;
|
||||
MLAS_QGEMM_OUTPUT_MODE OutputMode_;
|
||||
MLAS_QUANTIZATION_GRANULARITY QuantGran_;
|
||||
};
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
|
|
@ -222,27 +296,8 @@ MlasGemm(
|
|||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
const float* Scale,
|
||||
const float* Bias,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
|
||||
);
|
||||
|
||||
void
|
||||
|
|
@ -259,26 +314,8 @@ MlasGemm(
|
|||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const void* PackedB,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
const float* Scale,
|
||||
const float* Bias,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr
|
||||
);
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -40,13 +40,11 @@ struct MLAS_GEMM_U8X8_WORK_BLOCK {
|
|||
size_t ldb;
|
||||
int32_t* C;
|
||||
size_t ldc;
|
||||
const float* Scale;
|
||||
const float* BiasFloat;
|
||||
uint8_t offa;
|
||||
uint8_t offb;
|
||||
bool BIsPacked;
|
||||
bool BIsSigned;
|
||||
bool CIsFloat;
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor;
|
||||
};
|
||||
|
||||
//
|
||||
|
|
@ -84,123 +82,6 @@ MlasGemmU8X8ScaleSumBuffer(
|
|||
return MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale);
|
||||
}
|
||||
|
||||
void
|
||||
MlasGemmU8X8OutputFloat(
|
||||
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock,
|
||||
int32_t* C,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine converts the output matrix to a floating point format using
|
||||
the supplied scale and bias parameters.
|
||||
|
||||
Arguments:
|
||||
|
||||
WorkBlock - Supplies the structure containing the GEMM parameters.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
StartN - Supplies the starting column offset relative to the base of the
|
||||
work block. This is used to offset into column vectors accessed via the
|
||||
work block.
|
||||
|
||||
CountM - Supplies the number of rows of the output matrix to process.
|
||||
|
||||
CountN - Supplies the number of columns of the output matrix to process.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t ldc = WorkBlock->ldc;
|
||||
const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(WorkBlock->Scale);
|
||||
#if !defined(MLAS_SSE2_INTRINSICS)
|
||||
const float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector);
|
||||
#endif
|
||||
|
||||
//
|
||||
// Check if the optional bias vector was supplied.
|
||||
//
|
||||
|
||||
const float* BiasFloat = WorkBlock->BiasFloat;
|
||||
|
||||
if (BiasFloat != nullptr) {
|
||||
|
||||
BiasFloat += WorkBlock->RangeStartN + StartN;
|
||||
|
||||
while (CountM-- > 0) {
|
||||
|
||||
const float* bias = BiasFloat;
|
||||
int32_t* c = C;
|
||||
size_t n = CountN;
|
||||
|
||||
while (n >= 4) {
|
||||
|
||||
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
|
||||
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
|
||||
FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias));
|
||||
MlasStoreFloat32x4(reinterpret_cast<float*>(c), FloatVector);
|
||||
|
||||
bias += 4;
|
||||
c += 4;
|
||||
n -= 4;
|
||||
}
|
||||
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128 FloatVector = _mm_set_ss(float(c[offset]));
|
||||
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
|
||||
FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset]));
|
||||
_mm_store_ss(reinterpret_cast<float*>(&c[offset]), FloatVector);
|
||||
#else
|
||||
*reinterpret_cast<float*>(&c[offset]) = float(c[offset]) * ScaleValue + bias[offset];
|
||||
#endif
|
||||
}
|
||||
|
||||
C += ldc;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
while (CountM-- > 0) {
|
||||
|
||||
int32_t* c = C;
|
||||
size_t n = CountN;
|
||||
|
||||
while (n >= 4) {
|
||||
|
||||
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
|
||||
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
|
||||
MlasStoreFloat32x4(reinterpret_cast<float*>(c), FloatVector);
|
||||
|
||||
c += 4;
|
||||
n -= 4;
|
||||
}
|
||||
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128 FloatVector = _mm_set_ss((float)c[offset]);
|
||||
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
|
||||
_mm_store_ss(reinterpret_cast<float*>(&c[offset]), FloatVector);
|
||||
#else
|
||||
*reinterpret_cast<float*>(&c[offset]) = float(c[offset]) * ScaleValue;
|
||||
#endif
|
||||
}
|
||||
|
||||
C += ldc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename KernelType>
|
||||
void
|
||||
MLASCALL
|
||||
|
|
@ -251,7 +132,7 @@ Return Value:
|
|||
// Try to use a GEMV kernel if supported by this kernel type.
|
||||
//
|
||||
|
||||
if ((M == 1) && (offa == 0) && (offb == 0) && !WorkBlock->CIsFloat) {
|
||||
if ((M == 1) && (offa == 0) && (offb == 0) && WorkBlock->OutputProcessor == nullptr) {
|
||||
if (KernelType::TryGemvKernel(A, B, ldb, C, K, N, WorkBlock->BIsSigned)) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -346,8 +227,13 @@ Return Value:
|
|||
RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
|
||||
DepthValue, ZeroMode);
|
||||
|
||||
if (PostProcess && WorkBlock->CIsFloat) {
|
||||
MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN);
|
||||
if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
|
||||
WorkBlock->OutputProcessor->Process(WorkBlock->C,
|
||||
WorkBlock->RangeStartM + m + CountM - RowsRemaining,
|
||||
WorkBlock->RangeStartN + n,
|
||||
RowsHandled,
|
||||
CountN,
|
||||
WorkBlock->ldc);
|
||||
}
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
|
|
@ -508,8 +394,14 @@ Return Value:
|
|||
RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
|
||||
DepthValue, ZeroMode);
|
||||
|
||||
if (PostProcess && WorkBlock->CIsFloat) {
|
||||
MlasGemmU8X8OutputFloat(WorkBlock, c, n, RowsHandled, CountN);
|
||||
if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
|
||||
WorkBlock->OutputProcessor->Process(
|
||||
WorkBlock->C,
|
||||
WorkBlock->RangeStartM + m + CountM - RowsRemaining,
|
||||
WorkBlock->RangeStartN + n,
|
||||
RowsHandled,
|
||||
CountN,
|
||||
WorkBlock->ldc);
|
||||
}
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
|
|
@ -2187,7 +2079,8 @@ MlasGemm(
|
|||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -2227,6 +2120,8 @@ Arguments:
|
|||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
OutputProcessor - Post Processor on C.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
|
@ -2250,6 +2145,7 @@ Return Value:
|
|||
WorkBlock.ldb = ldb;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.OutputProcessor = OutputProcessor;
|
||||
WorkBlock.offa = offa;
|
||||
WorkBlock.offb = offb;
|
||||
WorkBlock.BIsSigned = BIsSigned;
|
||||
|
|
@ -2261,107 +2157,6 @@ Return Value:
|
|||
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const uint8_t* B,
|
||||
size_t ldb,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
const float* Scale,
|
||||
const float* Bias,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the quantized integer matrix/matrix multiply
|
||||
operation (QGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
offa - Supplies the zero point offset of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B.
|
||||
|
||||
ldb - Supplies the first dimension of matrix B.
|
||||
|
||||
offb - Supplies the zero point offset of matrix B.
|
||||
|
||||
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
||||
B is unsigned data.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
Scale - Supplies the scale multiplier to apply to each element of matrix C.
|
||||
Used to scale the integer output of the QGEMM back to a floating point
|
||||
number.
|
||||
|
||||
Bias - Supplies the bias vector to apply to element of matrix C. The vector
|
||||
is of length N.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
|
||||
|
||||
//
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
|
||||
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = B;
|
||||
WorkBlock.ldb = ldb;
|
||||
WorkBlock.C = (int32_t*)C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.Scale = Scale;
|
||||
WorkBlock.BiasFloat = Bias;
|
||||
WorkBlock.offa = offa;
|
||||
WorkBlock.offb = offb;
|
||||
WorkBlock.BIsSigned = BIsSigned;
|
||||
WorkBlock.CIsFloat = true;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
#endif // MLAS_SUPPORTS_GEMM_U8X8
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
|
|
@ -2380,7 +2175,8 @@ MlasGemm(
|
|||
bool BIsSigned,
|
||||
int32_t* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
MLAS_THREADPOOL* ThreadPool,
|
||||
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -2418,100 +2214,7 @@ Arguments:
|
|||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
|
||||
|
||||
//
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
|
||||
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = PackedB;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.offa = offa;
|
||||
WorkBlock.offb = offb;
|
||||
WorkBlock.BIsPacked = true;
|
||||
WorkBlock.BIsSigned = BIsSigned;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const uint8_t* A,
|
||||
size_t lda,
|
||||
uint8_t offa,
|
||||
const void* PackedB,
|
||||
uint8_t offb,
|
||||
bool BIsSigned,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
const float* Scale,
|
||||
const float* Bias,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the quantized integer matrix/matrix multiply
|
||||
operation (QGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
offa - Supplies the zero point offset of matrix A.
|
||||
|
||||
PackedB - Supplies the address of packed matrix B.
|
||||
|
||||
offb - Supplies the zero point offset of matrix B.
|
||||
|
||||
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
||||
B is unsigned data.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
Scale - Supplies the scale multiplier to apply to each element of matrix C.
|
||||
Used to scale the integer output of the QGEMM back to a floating point
|
||||
number.
|
||||
|
||||
Bias - Supplies the bias vector to apply to element of matrix C. The vector
|
||||
is of length N.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
OutputProcessor - Post Processor on C
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -2535,13 +2238,11 @@ Return Value:
|
|||
WorkBlock.B = PackedB;
|
||||
WorkBlock.C = (int32_t*)C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.Scale = Scale;
|
||||
WorkBlock.BiasFloat = Bias;
|
||||
WorkBlock.OutputProcessor = OutputProcessor,
|
||||
WorkBlock.offa = offa;
|
||||
WorkBlock.offb = offb;
|
||||
WorkBlock.BIsPacked = true;
|
||||
WorkBlock.BIsSigned = BIsSigned;
|
||||
WorkBlock.CIsFloat = true;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
|
|
|
|||
242
onnxruntime/core/mlas/lib/qpostprocessor.cpp
Normal file
242
onnxruntime/core/mlas/lib/qpostprocessor.cpp
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
qpostprocessor.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the post processor for QGEMM.
|
||||
|
||||
--*/
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
|
||||
void MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::Process(
|
||||
const int32_t* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc
|
||||
) const
|
||||
{
|
||||
if (Bias_) {
|
||||
if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
|
||||
if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
} else {
|
||||
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
}
|
||||
} else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
} else {
|
||||
ProcessImpl<true, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
}
|
||||
} else {
|
||||
if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
|
||||
if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
} else {
|
||||
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerColumn>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
}
|
||||
} else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::AccumulateMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
} else {
|
||||
ProcessImpl<false, MLAS_QGEMM_OUTPUT_MODE::ZeroMode, MLAS_QUANTIZATION_GRANULARITY::PerMatrix>(
|
||||
C,
|
||||
StartM,
|
||||
StartN,
|
||||
CountM,
|
||||
CountN,
|
||||
ldc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool HasBias, MLAS_QGEMM_OUTPUT_MODE Mode, MLAS_QUANTIZATION_GRANULARITY QuantGran>
|
||||
inline
|
||||
void
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::ProcessImpl(
|
||||
const int32_t* C,
|
||||
size_t StartM,
|
||||
size_t StartN,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t ldc) const
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine converts the output matrix C to a floating point format using
|
||||
the stored scale and bias parameters.
|
||||
|
||||
Arguments:
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
StartM - Supplies the starting row offset relative to the matrix.
|
||||
|
||||
StartN - Supplies the starting column offset relative to the matrix.
|
||||
|
||||
CountM - Supplies the number of rows of the output matrix to process.
|
||||
|
||||
CountN - Supplies the number of columns of the output matrix to process.
|
||||
|
||||
ldc - Supplies the leading dimension of C.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
float* Output = Output_;
|
||||
const float* Bias = Bias_;
|
||||
const float* Scale = Scale_;
|
||||
|
||||
if (HasBias) {
|
||||
Bias += StartN;
|
||||
}
|
||||
|
||||
if(QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn){
|
||||
Scale += StartN;
|
||||
}
|
||||
|
||||
MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale_);
|
||||
#if !defined(MLAS_SSE2_INTRINSICS)
|
||||
float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector);
|
||||
#endif
|
||||
|
||||
C += StartM * ldc + StartN;
|
||||
Output += StartM * LeadingDimensionOutput_ + StartN;
|
||||
|
||||
|
||||
while (CountM-- > 0) {
|
||||
|
||||
float* c_out = Output;
|
||||
const int32_t* c = C;
|
||||
const float* bias = Bias;
|
||||
const float* scale = Scale;
|
||||
|
||||
size_t n = CountN;
|
||||
|
||||
while (n >= 4) {
|
||||
|
||||
MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c));
|
||||
|
||||
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
|
||||
ScaleVector = MlasLoadFloat32x4(scale);
|
||||
scale += 4;
|
||||
}
|
||||
|
||||
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
FloatVector = MlasMultiplyAddFloat32x4(FloatVector, ScaleVector, MlasLoadFloat32x4(c_out));
|
||||
} else {
|
||||
FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector);
|
||||
}
|
||||
|
||||
if (HasBias) {
|
||||
FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias));
|
||||
bias += 4;
|
||||
}
|
||||
|
||||
MlasStoreFloat32x4(c_out, FloatVector);
|
||||
|
||||
c_out += 4;
|
||||
c += 4;
|
||||
n -= 4;
|
||||
}
|
||||
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
|
||||
#if defined(MLAS_SSE2_INTRINSICS)
|
||||
__m128 FloatVector = _mm_set_ss(float(c[offset]));
|
||||
|
||||
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
|
||||
ScaleVector = _mm_load_ss(&scale[offset]);
|
||||
}
|
||||
|
||||
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
FloatVector = _mm_add_ps(_mm_mul_ss(FloatVector, ScaleVector), _mm_load_ss(&c_out[offset]));
|
||||
} else {
|
||||
FloatVector = _mm_mul_ss(FloatVector, ScaleVector);
|
||||
}
|
||||
|
||||
if (HasBias) {
|
||||
FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset]));
|
||||
}
|
||||
|
||||
_mm_store_ss(&c_out[offset], FloatVector);
|
||||
#else
|
||||
if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) {
|
||||
ScaleValue = scale[offset];
|
||||
}
|
||||
|
||||
float result = float(c[offset]) * ScaleValue;
|
||||
if (HasBias) {
|
||||
result += bias[offset];
|
||||
}
|
||||
|
||||
if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) {
|
||||
c_out[offset] += result;
|
||||
} else {
|
||||
c_out[offset] = result;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
C += ldc;
|
||||
Output += LeadingDimensionOutput_;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MLAS_SUPPORTS_GEMM_U8X8
|
||||
|
|
@ -85,7 +85,13 @@ void QGemm(
|
|||
const float* bias,
|
||||
concurrency::ThreadPool* thread_pool) {
|
||||
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
||||
MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, result_data, ldc, result_scale, bias, thread_pool);
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(result_data, ldc, result_scale, bias);
|
||||
MlasGemm(M, N, K,
|
||||
lhs_data, lda, lhs_offset,
|
||||
rhs_data, ldb, rhs_offset, rhs_signed,
|
||||
reinterpret_cast<int32_t*>(result_data), ldc,
|
||||
thread_pool,
|
||||
&scale_bias_processor);
|
||||
#else
|
||||
QGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, rhs_signed, reinterpret_cast<int32_t*>(result_data), ldc, thread_pool);
|
||||
for (int m = 0; m < M; m++) {
|
||||
|
|
|
|||
|
|
@ -573,7 +573,13 @@ protected:
|
|||
const float* Bias
|
||||
)
|
||||
{
|
||||
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool);
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
|
||||
MlasGemm(M, N, K,
|
||||
A, lda, offa,
|
||||
B, ldb, offb, BIsSigned,
|
||||
reinterpret_cast<int32_t*>(C), ldc,
|
||||
threadpool,
|
||||
&scale_bias_processor);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -638,7 +644,13 @@ protected:
|
|||
)
|
||||
{
|
||||
const void* PackedB = PackB(N, K, B, ldb, BIsSigned);
|
||||
MlasGemm(M, N, K, A, lda, offa, PackedB, offb, BIsSigned, C, ldc, &CScale, Bias, threadpool);
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(C, ldc, &CScale, Bias);
|
||||
MlasGemm(M, N, K,
|
||||
A, lda, offa,
|
||||
PackedB, offb, BIsSigned,
|
||||
reinterpret_cast<int32_t*>(C), ldc,
|
||||
threadpool,
|
||||
&scale_bias_processor);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -2756,6 +2768,92 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class MlasScaleOutputTest : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
MatrixGuardBuffer<int32_t> BufferInput;
|
||||
MatrixGuardBuffer<float> BufferOutput;
|
||||
MatrixGuardBuffer<float> BufferOutputRef;
|
||||
MatrixGuardBuffer<float> BufferScale;
|
||||
|
||||
void
|
||||
Test(
|
||||
size_t M,
|
||||
size_t N,
|
||||
bool PerColumn,
|
||||
bool AccumulateMode
|
||||
)
|
||||
{
|
||||
int32_t* Input = BufferInput.GetBuffer(M * N);
|
||||
float* Output = BufferOutput.GetBuffer(M * N);
|
||||
float* OutputRef = BufferOutputRef.GetBuffer(M * N);
|
||||
float* Scale = BufferScale.GetBuffer(PerColumn ? N : 1);
|
||||
|
||||
std::default_random_engine generator(static_cast<unsigned>(M * N));
|
||||
std::uniform_real_distribution<float> real_distribution(-1.0f, 1.0f);
|
||||
std::uniform_int_distribution<int32_t> int_distribution(std::numeric_limits<int16_t>::min(),
|
||||
std::numeric_limits<int16_t>::max());
|
||||
|
||||
for (size_t s = 0; s < M * N; s++) {
|
||||
|
||||
Input[s] = int_distribution(generator);
|
||||
Output[s] = OutputRef[s] = real_distribution(generator);
|
||||
}
|
||||
|
||||
for (size_t s = 0; s < (PerColumn ? N : 1); s++) {
|
||||
Scale[s] = real_distribution(generator);
|
||||
}
|
||||
|
||||
// Compute Reference Value
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
|
||||
float current_scale = PerColumn ? Scale[n] : Scale[0];
|
||||
if (AccumulateMode) {
|
||||
OutputRef[m * N + n] += Input[m * N + n] * current_scale;
|
||||
} else {
|
||||
OutputRef[m * N + n] = Input[m * N + n] * current_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute Output with MLAS
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR OutputProcessor(Output, N, Scale, nullptr,
|
||||
AccumulateMode ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode,
|
||||
PerColumn ? MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix);
|
||||
OutputProcessor.Process(Input, 0, 0, M, N, N);
|
||||
|
||||
constexpr float epsilon = 1e-6f;
|
||||
|
||||
for (size_t n = 0; n < M * N; n++) {
|
||||
|
||||
float diff = std::fabs(Output[n] - OutputRef[n]);
|
||||
if (diff > epsilon) {
|
||||
printf("MlasScaleOutputTest: Output[%zu][%zu]:%.8f, OutputRef[%zu][%zu]:%.8f, for case M=%zu, N=%zu\n",
|
||||
n / N, n % N, Output[n], n / N, n % N, OutputRef[n], M, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
ExecuteShort(
|
||||
void
|
||||
) override
|
||||
{
|
||||
for (size_t m = 1; m < 18; m++) {
|
||||
|
||||
for (size_t n = 1; n < 18; n++) {
|
||||
|
||||
Test(m, n, true, true);
|
||||
Test(m, n, true, false);
|
||||
Test(m, n, false, true);
|
||||
Test(m, n, false, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
RunThreadedTests(
|
||||
void
|
||||
|
|
@ -2870,6 +2968,9 @@ main(
|
|||
onnxruntime::make_unique<MlasQLinearBinaryOpTest>(
|
||||
[] (float a, float b) { return a * b; }, "*", MlasQLinearMul<int8_t>, MlasQLinearMul<uint8_t>)->ExecuteShort();
|
||||
|
||||
printf("MlasScaleOutput tests.\n");
|
||||
onnxruntime::make_unique<MlasScaleOutputTest>()->ExecuteShort();
|
||||
|
||||
printf("Done.\n");
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
Loading…
Reference in a new issue