mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
* Implement a Scale function for quantization. Quantized GEMM is always followed by scaling (per-tensor or per-column), and the result often needs to be accumulated into an existing matrix. This PR implements a post-processor for the quantized GEMM result that accumulates it into another matrix.
2460 lines
64 KiB
C++
2460 lines
64 KiB
C++
/*++
|
|
|
|
Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
Licensed under the MIT License.
|
|
|
|
Module Name:
|
|
|
|
qgemm.cpp
|
|
|
|
Abstract:
|
|
|
|
This module implements the quantized integer matrix/matrix multiply
|
|
operation (QGEMM).
|
|
|
|
--*/
|
|
|
|
#include "mlasi.h"

#include <cstring>
|
|
|
|
#ifdef MLAS_SUPPORTS_GEMM_U8X8
|
|
|
|
//
|
|
// Define the parameters to execute segments of a QGEMM operation on worker
|
|
// threads.
|
|
//
|
|
|
|
struct MLAS_GEMM_U8X8_WORK_BLOCK {
    // Partition counts along M and N. Not read by the routines in this
    // section; presumably consumed by the threading driver.
    int32_t ThreadCountM;
    int32_t ThreadCountN;
    // First row/column and extent of the output sub-range of matrix C that
    // this work block computes.
    size_t RangeStartM;
    size_t RangeStartN;
    size_t RangeCountM;
    size_t RangeCountN;
    // Dimensions of the full operation. K is the shared depth; N (the full
    // width) also locates the packed data that follows the column sums of a
    // packed matrix B.
    size_t M;
    size_t N;
    size_t K;
    // Matrix A (uint8_t elements) and its row stride in elements.
    const uint8_t* A;
    size_t lda;
    // Matrix B: either row-major data with row stride ldb (K rows by N
    // columns), or, for the packed operation, a packed buffer holding the
    // column sums followed by the packed data.
    const void* B;
    size_t ldb;
    // Output matrix C (int32_t elements) and its row stride in elements.
    int32_t* C;
    size_t ldc;
    // Zero point offsets of matrix A and matrix B.
    uint8_t offa;
    uint8_t offb;
    // Not read in this section; presumably selects the packed-B operation.
    bool BIsPacked;
    // true if matrix B holds signed data, false if it holds unsigned data.
    bool BIsSigned;
    // Optional post-processor applied to each finished output tile after the
    // final K slice has been accumulated; nullptr disables post-processing
    // (and is required for the GEMV fast path).
    const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor;
};
|
|
|
|
//
|
|
// Define the default striding parameters used for the quantized integer
|
|
// matrix/matrix multiply operation.
|
|
//
|
|
|
|
struct MLAS_GEMM_U8X8_STRIDES {
    size_t M;   // Maximum rows of matrix A packed per panel (bounds CountM).
    size_t N;   // Maximum columns of matrix B handled per panel (bounds CountN).
    size_t K;   // Maximum depth of one K slice (bounds CountK).
};
|
|
|
|
void
MlasGemmU8X8ScaleSumBuffer(
    int32_t* Output,
    const int32_t* Input,
    size_t N,
    int32_t Scale
    )
/*++

Routine Description:

    This routine multiplies every element of a sum buffer by a scale factor
    and writes the products to an output buffer. Output may be the same
    buffer as Input for an in-place update.

Arguments:

    Output - Supplies the buffer that receives the scaled values.

    Input - Supplies the buffer of values to scale.

    N - Supplies the number of elements to process.

    Scale - Supplies the multiplier applied to each element.

Return Value:

    None.

--*/
{
    const int32_t* const InputEnd = Input + N;

    while (Input < InputEnd) {
        *Output++ = *Input++ * Scale;
    }
}
|
|
|
|
MLAS_FORCEINLINE
void
MlasGemmU8X8ScaleSumBuffer(
    int32_t* SumBuffer,
    size_t N,
    int32_t Scale
    )
/*++

Routine Description:

    This routine multiplies every element of a sum buffer by a scale factor
    in place. It forwards to the two-buffer overload, passing SumBuffer as
    both the input and the output.

Arguments:

    SumBuffer - Supplies the buffer to scale in place.

    N - Supplies the number of elements to process.

    Scale - Supplies the multiplier applied to each element.

Return Value:

    None.

--*/
{
    MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale);
}
|
|
|
|
template<typename KernelType>
void
MLASCALL
MlasGemmU8X8Operation(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    )
/*++

Routine Description:

    This routine implements the quantized integer matrix/matrix multiply
    operation (QGEMM) for the output sub-range described by a work block.
    Panels of matrix A and matrix B are copied into local packed buffers by
    the kernel type's CopyPackA/CopyPackB routines and multiplied by its
    GemmKernel routine.

    Zero points are handled by expanding sum((a - offa) * (b - offb)) into
    the raw product plus three correction terms that the kernel adds to each
    output element: the row sums of A scaled by -offb, the column sums of B
    scaled by -offa, and the constant CountK * offa * offb (DepthValue).

Arguments:

    WorkBlock - Supplies the structure containing the GEMM parameters. Rows
        [RangeStartM, RangeStartM + RangeCountM) and columns
        [RangeStartN, RangeStartN + RangeCountN) of matrix C are computed.

Return Value:

    None.

--*/
{
    //
    // Blocking factors for this kernel type; they also size the stack
    // buffers below.
    //

    constexpr MLAS_GEMM_U8X8_STRIDES Strides = KernelType::Strides;

    MLAS_DECLSPEC_ALIGN(typename KernelType::PackedAType PanelA[Strides.M * Strides.K], 64);
    MLAS_DECLSPEC_ALIGN(typename KernelType::PackedBType PanelB[Strides.N * Strides.K], 64);

    MLAS_DECLSPEC_ALIGN(int32_t RowSumBuffer[Strides.M], 64);
    MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[Strides.N], 64);

    //
    // M and N are the extents of this work block rather than of the whole
    // matrix; A, B and C are rebased to the block's first row/column.
    //

    const size_t M = WorkBlock->RangeCountM;
    const size_t N = WorkBlock->RangeCountN;
    const size_t K = WorkBlock->K;

    const size_t lda = WorkBlock->lda;
    const size_t ldb = WorkBlock->ldb;
    const size_t ldc = WorkBlock->ldc;

    const uint8_t* A = WorkBlock->A + WorkBlock->RangeStartM * lda;
    const uint8_t* B = (const uint8_t*)WorkBlock->B + WorkBlock->RangeStartN;
    int32_t* C = WorkBlock->C + WorkBlock->RangeStartM * ldc + WorkBlock->RangeStartN;

    int32_t offa = WorkBlock->offa;

    // Converting through OffsetBType sign extends the zero point when the
    // kernel type uses a signed OffsetBType.
    int32_t offb = typename KernelType::OffsetBType(WorkBlock->offb);

    //
    // Try to use a GEMV kernel if supported by this kernel type. The GEMV
    // call receives no zero points and the routine returns right after it,
    // so the path is only taken for a single row with both zero points equal
    // to zero and no output processor.
    //

    if ((M == 1) && (offa == 0) && (offb == 0) && WorkBlock->OutputProcessor == nullptr) {
        if (KernelType::TryGemvKernel(A, B, ldb, C, K, N, WorkBlock->BIsSigned)) {
            return;
        }
    }

    //
    // Flip the sign bit of the zero point offset of matrix B if the kernel uses
    // signed types and the matrix B data is unsigned.
    //
    // On NEON the flip is instead applied when the matrix B data is signed.
    //

#if defined(MLAS_SSE2_INTRINSICS)
    if (std::is_signed<typename KernelType::OffsetBType>::value) {
        if (!WorkBlock->BIsSigned) {
            offb = typename KernelType::OffsetBType(offb ^ 0x80);
        }
    }
#elif defined(MLAS_NEON_INTRINSICS)
    if (WorkBlock->BIsSigned) {
        offb = typename KernelType::OffsetBType(offb ^ 0x80);
    }
#endif

    //
    // Step through each slice of matrix B along the K dimension.
    //

    size_t CountK;

    for (size_t k = 0; k < K; k += CountK) {

        CountK = std::min(K - k, Strides.K);

        //
        // Step through each slice of matrix B along the N dimension.
        //

        size_t CountN;

        for (size_t n = 0; n < N; n += CountN) {

            CountN = std::min(N - n, Strides.N);

            //
            // Copy a panel of matrix B to a local packed buffer. The packing
            // routine also produces the column sums, which are scaled by
            // -offa to form the matrix A zero point correction.
            //

            KernelType::CopyPackB(PanelB, B + n, ldb, CountN, CountK,
                ColumnSumBuffer, WorkBlock->BIsSigned);

            MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, CountN, -offa);

            //
            // Step through each slice of matrix A along the M dimension.
            //
            // DepthValue is the constant correction term for this K slice.
            // PackedCountK is CountK rounded up to whole groups of PackedK.
            //

            const int32_t DepthValue = int32_t(CountK) * offa * offb;
            const size_t PackedCountK = (CountK + KernelType::PackedK - 1) /
                KernelType::PackedK;

            int32_t* c = C + n;
            size_t CountM;

            for (size_t m = 0; m < M; m += CountM) {

                CountM = std::min(M - m, Strides.M);

                //
                // Copy a panel of matrix A to a local packed buffer. The row
                // sums are scaled by -offb to form the matrix B zero point
                // correction.
                //

                KernelType::CopyPackA(PanelA, A + m * lda, lda, CountM, CountK,
                    RowSumBuffer);

                MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -offb);

                //
                // Step through the rows of the local packed buffer.
                //

                typename KernelType::PackedAType* pa = PanelA;
                int32_t* RowSums = RowSumBuffer;
                size_t RowsRemaining = CountM;

                //
                // The first K slice overwrites matrix C and later slices add
                // into it. Output processing happens only after the last K
                // slice, when the results are complete.
                //

                bool ZeroMode = (k == 0);
                bool PostProcess = (k + CountK == K);

                while (RowsRemaining > 0) {

                    size_t RowsHandled;

                    //
                    // The kernel returns how many rows it completed, which may
                    // be fewer than RowsRemaining.
                    //

                    RowsHandled = KernelType::GemmKernel(pa, PanelB, c, PackedCountK,
                        RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
                        DepthValue, ZeroMode);

                    //
                    // The output processor is addressed in coordinates of the
                    // full matrix C. The starting row is the work block's
                    // first row plus the rows already completed in this panel.
                    //

                    if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
                        WorkBlock->OutputProcessor->Process(WorkBlock->C,
                            WorkBlock->RangeStartM + m + CountM - RowsRemaining,
                            WorkBlock->RangeStartN + n,
                            RowsHandled,
                            CountN,
                            WorkBlock->ldc);
                    }

                    c += ldc * RowsHandled;
                    pa += KernelType::PackedK * PackedCountK * RowsHandled;
                    RowSums += RowsHandled;
                    RowsRemaining -= RowsHandled;
                }
            }
        }

        //
        // Advance matrix A and matrix B to the next K slice.
        //

        A += CountK;
        B += CountK * ldb;
    }
}
|
|
|
|
template<typename KernelType>
void
MLASCALL
MlasGemmU8X8PackedOperation(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    )
/*++

Routine Description:

    This routine implements the quantized integer matrix/matrix multiply
    operation (QGEMM) for a work block whose matrix B has already been
    packed. Only panels of matrix A are packed locally. The packed matrix B
    panels and its column sums are read directly from WorkBlock->B.

Arguments:

    WorkBlock - Supplies the structure containing the GEMM parameters.

Return Value:

    None.

--*/
{
    constexpr MLAS_GEMM_U8X8_STRIDES Strides = KernelType::PackedStrides;

    MLAS_DECLSPEC_ALIGN(typename KernelType::PackedAType PanelA[Strides.M * Strides.K], 64);

    MLAS_DECLSPEC_ALIGN(int32_t RowSumBuffer[Strides.M], 64);
    MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[Strides.N], 64);

    const size_t M = WorkBlock->RangeCountM;
    const size_t N = WorkBlock->RangeCountN;
    const size_t K = WorkBlock->K;

    const size_t lda = WorkBlock->lda;
    const size_t ldc = WorkBlock->ldc;

    const uint8_t* A = WorkBlock->A + WorkBlock->RangeStartM * lda;
    const uint8_t* PackedB = (const uint8_t*)WorkBlock->B;
    int32_t* C = WorkBlock->C + WorkBlock->RangeStartM * ldc + WorkBlock->RangeStartN;

    int32_t offa = WorkBlock->offa;
    int32_t offb = typename KernelType::OffsetBType(WorkBlock->offb);

    //
    // Flip the sign bit of the zero point offset of matrix B if the kernel uses
    // signed types and the matrix B data is unsigned.
    //
    // On NEON the flip is instead applied when the matrix B data is signed.
    //

#if defined(MLAS_SSE2_INTRINSICS)
    if (std::is_signed<typename KernelType::OffsetBType>::value) {
        if (!WorkBlock->BIsSigned) {
            offb = typename KernelType::OffsetBType(offb ^ 0x80);
        }
    }
#elif defined(MLAS_NEON_INTRINSICS)
    if (WorkBlock->BIsSigned) {
        offb = typename KernelType::OffsetBType(offb ^ 0x80);
    }
#endif

    //
    // Extract the pointer to the column sum buffer from the packed matrix.
    //
    // The packed buffer starts with AlignedN int32_t column sums, one per
    // column of the full matrix B, with N rounded up to a multiple of
    // MLAS_QGEMM_STRIDEN_THREAD_ALIGN. The mask arithmetic assumes that value
    // is a power of two. The packed matrix B data follows.
    //

    const size_t AlignedN =
        (WorkBlock->N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1);
    const int32_t* PackedColumnSumBuffer = (const int32_t*)PackedB;
    PackedB = (const uint8_t*)(PackedColumnSumBuffer + AlignedN);
    PackedColumnSumBuffer += WorkBlock->RangeStartN;

    //
    // Step through each slice of matrix B along the K dimension.
    //

    size_t CountK;

    for (size_t k = 0; k < K; k += CountK) {

        CountK = std::min(K - k, Strides.K);

        const size_t PackedCountK = (CountK + KernelType::PackedK - 1) /
            KernelType::PackedK;

        //
        // The packed column sums are applied in full during the first K slice
        // (see below). Later slices use zero column sums so the sums are not
        // counted again.
        //

        if (k > 0) {
            std::fill_n(ColumnSumBuffer, Strides.N, 0);
        }

        //
        // Step through each slice of matrix B along the N dimension.
        //

        size_t CountN;

        for (size_t n = 0; n < N; n += CountN) {

            CountN = std::min(N - n, Strides.N);

            if (k == 0) {
                MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, PackedColumnSumBuffer + n,
                    CountN, -offa);
            }

            //
            // Step through each slice of matrix A along the M dimension.
            //
            // Within a K slice, each packed column occupies
            // PackedK * PackedCountK bytes, so this is the start of the panel
            // for column RangeStartN + n.
            //

            const int32_t DepthValue = int32_t(CountK) * offa * offb;
            const uint8_t* b = PackedB + (WorkBlock->RangeStartN + n) *
                KernelType::PackedK * PackedCountK;
            int32_t* c = C + n;
            size_t CountM;

            for (size_t m = 0; m < M; m += CountM) {

                CountM = std::min(M - m, Strides.M);

                //
                // Copy a panel of matrix A to a local packed buffer. The row
                // sums are scaled by -offb to form the matrix B zero point
                // correction.
                //

                KernelType::CopyPackA(PanelA, A + m * lda, lda, CountM, CountK,
                    RowSumBuffer);

                MlasGemmU8X8ScaleSumBuffer(RowSumBuffer, CountM, -offb);

                //
                // Step through the rows of the local packed buffer.
                //

                typename KernelType::PackedAType* pa = PanelA;
                int32_t* RowSums = RowSumBuffer;
                size_t RowsRemaining = CountM;

                //
                // The first K slice overwrites matrix C and later slices add
                // into it. Output processing happens only after the last K
                // slice.
                //

                bool ZeroMode = (k == 0);
                bool PostProcess = (k + CountK == K);

                while (RowsRemaining > 0) {

                    size_t RowsHandled;

                    RowsHandled = KernelType::GemmKernel(pa, b, c, PackedCountK,
                        RowsRemaining, CountN, ldc, RowSums, ColumnSumBuffer,
                        DepthValue, ZeroMode);

                    //
                    // The output processor is addressed in coordinates of the
                    // full matrix C.
                    //

                    if (PostProcess && WorkBlock->OutputProcessor != nullptr) {
                        WorkBlock->OutputProcessor->Process(
                            WorkBlock->C,
                            WorkBlock->RangeStartM + m + CountM - RowsRemaining,
                            WorkBlock->RangeStartN + n,
                            RowsHandled,
                            CountN,
                            WorkBlock->ldc);
                    }

                    c += ldc * RowsHandled;
                    pa += KernelType::PackedK * PackedCountK * RowsHandled;
                    RowSums += RowsHandled;
                    RowsRemaining -= RowsHandled;
                }
            }
        }

        //
        // Advance to the next K slice. A packed slice occupies
        // AlignedN * PackedCountK * PackedK bytes. This equals
        // AlignedN * CountK for every slice except possibly the last, assuming
        // Strides.K is a multiple of PackedK (true for the kernel types in
        // this file).
        //

        A += CountK;
        PackedB = (const uint8_t*)PackedB + AlignedN * CountK;
    }
}
|
|
|
|
#ifdef MLAS_SSE2_INTRINSICS
|
|
|
|
void
MlasGemmU8X8CopyPackASse(
    int16_t* D,
    const uint8_t* A,
    size_t lda,
    size_t CountM,
    size_t CountK,
    int32_t* RowSumBuffer
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer, zero extending each byte to 16 bits, and computes the sum
    of each copied row.

Arguments:

    D - Supplies the address of the destination packed buffer. Each row
        occupies CountK rounded up to an even number of int16_t elements.

    A - Supplies the address of the source matrix.

    lda - Supplies the number of elements per row of the source matrix.

    CountM - Supplies the number of rows of the source matrix to copy.

    CountK - Supplies the number of columns of the source matrix to copy.

    RowSumBuffer - Supplies the address of the buffer to receive the sums of
        the elements along each of the rows.

Return Value:

    None.

--*/
{
    const __m128i ZeroVector = _mm_setzero_si128();
    const __m128i OnesWordBroadcast = _mm_set1_epi16(1);
    uint8_t PaddedMatrixAData[8] = { 0 };

    //
    // Process a single row of matrix A in a loop.
    //

    while (CountM > 0) {

        const uint8_t* a = A;
        size_t k = CountK;
        __m128i ReductionVector = ZeroVector;

        //
        // Zero extend the source bytes to 16-bits and write to the packed
        // buffer.
        //
        // The packed buffer has the same data ordering as the source bytes,
        // but CountK is aligned up to a multiple of 2 to maintain 32-bit
        // alignment. All extra bytes are zero-padded.
        //
        // These 16-bit values are also accumulated into an intermediate per-row
        // accumulator. CountK cannot be greater than 128 to avoid overflowing
        // these signed 16-bit accumulators.
        //

        while (k >= 8) {

            __m128i Bytes = _mm_loadl_epi64((__m128i*)&a[0]);
            __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector);

            ReductionVector = _mm_add_epi16(ReductionVector, Words);

            _mm_storeu_si128((__m128i*)&D[0], Words);

            D += 8;
            a += 8;
            k -= 8;
        }

        if (k > 0) {

            //
            // Copy the remaining bytes to the zero padded stack buffer. CountK
            // is the same for every row, so bytes beyond k stay zero across
            // rows.
            //

            uint8_t* padded = PaddedMatrixAData;
            uint8_t* padded_end = padded + k;

            do {
                padded[0] = a[0];
                padded++;
                a++;
            } while (padded < padded_end);

            __m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData);
            __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector);

            ReductionVector = _mm_add_epi16(ReductionVector, Words);

            //
            // Copy pairs of 16-bit values from the vector to the packed
            // buffer and rotate the vector for the next iteration.
            //
            // The pair is stored with memcpy rather than through an int32_t*
            // cast, so the int16_t destination is never accessed through an
            // incompatible type (strict aliasing).
            //

            for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) {
                const int32_t Pair = _mm_cvtsi128_si32(Words);
                std::memcpy(D, &Pair, sizeof(Pair));
                D += 2;
                Words = _mm_shuffle_epi32(Words, _MM_SHUFFLE(0, 3, 2, 1));
            }
        }

        //
        // Reduce the partial accumulators.
        //

        ReductionVector = _mm_madd_epi16(ReductionVector, OnesWordBroadcast);
        ReductionVector = _mm_add_epi32(ReductionVector,
            _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(3, 2, 3, 2)));
        ReductionVector = _mm_add_epi32(ReductionVector,
            _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(0, 1, 0, 1)));

        *RowSumBuffer++ = _mm_cvtsi128_si32(ReductionVector);

        A += lda;
        CountM -= 1;
    }
}
|
|
|
|
void
MlasGemmU8X8CopyPackBProcessSse(
    int16_t* D,
    __m128i BytesRow0,
    __m128i BytesRow1,
    __m128i BitFlipVector,
    __m128i ColumnSums[2]
    )
/*++

Routine Description:

    This routine interleaves the low 8 bytes of two rows of matrix B, XORs
    the interleaved bytes with BitFlipVector, sign extends them to 16 bits,
    stores the 16 resulting values to the packed buffer, and adds them to the
    running 16-bit column accumulators.

Arguments:

    D - Supplies the address of the destination packed buffer (16 int16_t).

    BytesRow0 - Supplies the bytes of the first row in the low 64 bits.

    BytesRow1 - Supplies the bytes of the second row in the low 64 bits.

    BitFlipVector - Supplies the mask XORed into every interleaved byte.

    ColumnSums - Supplies the two accumulator vectors to update.

Return Value:

    None.

--*/
{
    const __m128i Interleaved =
        _mm_xor_si128(_mm_unpacklo_epi8(BytesRow0, BytesRow1), BitFlipVector);

    //
    // Placing each byte in the upper half of a 16-bit lane and arithmetic
    // shifting right by 8 sign extends it.
    //

    const __m128i Zero = _mm_setzero_si128();
    const __m128i LowWords = _mm_srai_epi16(_mm_unpacklo_epi8(Zero, Interleaved), 8);
    const __m128i HighWords = _mm_srai_epi16(_mm_unpackhi_epi8(Zero, Interleaved), 8);

    _mm_storeu_si128((__m128i*)&D[0], LowWords);
    _mm_storeu_si128((__m128i*)&D[8], HighWords);

    ColumnSums[0] = _mm_add_epi16(ColumnSums[0], LowWords);
    ColumnSums[1] = _mm_add_epi16(ColumnSums[1], HighWords);
}
|
|
|
|
void
MlasGemmU8X8CopyPackBSse(
    int16_t* D,
    const uint8_t* B,
    size_t ldb,
    size_t CountN,
    size_t CountK,
    int32_t* ColumnSumBuffer,
    bool BIsSigned
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer. Pairs of rows are interleaved, every byte is XORed with
    0x80 when the source is unsigned (so it can be treated as signed), and
    the bytes are sign extended to 16 bits.

Arguments:

    D - Supplies the address of the destination packed buffer.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

    CountN - Supplies the number of columns of the source matrix to copy.

    CountK - Supplies the number of rows of the source matrix to copy.

    ColumnSumBuffer - Supplies the address of the buffer to receive the sums of
        the elements along each of the columns. Sums are always stored in
        groups of 8, so the buffer must hold CountN rounded up to a multiple
        of 8 entries.

    BIsSigned - Supplies true if the source matrix is signed data, else false
        if the source matrix is unsigned data.

Return Value:

    None.

--*/
{
    const __m128i OnesWordBroadcast = _mm_set1_epi16(1);
    const __m128i BitFlipVector = _mm_set1_epi32(BIsSigned ? 0 : 0x80808080);

    //
    // Process 8 columns of matrix B in a loop.
    //

    while (CountN >= 8) {

        const uint8_t* b = B;
        size_t k = CountK;
        __m128i ColumnSums[2];

        ColumnSums[0] = _mm_setzero_si128();
        ColumnSums[1] = _mm_setzero_si128();

        //
        // Interleave rows of matrix B and write to the packed buffer.
        //
        // These values are also sign extended to 16 bits and accumulated into
        // an intermediate per-column accumulator. CountK cannot be greater
        // than 128 to avoid overflowing these signed 16-bit accumulators.
        //

        while (k >= 2) {

            __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);
            __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&b[ldb]);

            MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums);

            b += ldb * 2;
            D += 16;
            k -= 2;
        }

        //
        // For an odd CountK, the final row is paired with BitFlipVector. The
        // XOR with the (byte-uniform) BitFlipVector inside
        // MlasGemmU8X8CopyPackBProcessSse turns that row into zeros.
        //

        if (k > 0) {

            __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&b[0]);

            MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums);

            D += 16;
        }

        //
        // Reduce the partial accumulators: each pair of adjacent 16-bit lanes
        // holds the two interleaved rows of one column and is summed into
        // that column's 32-bit total.
        //

        ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast);
        ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast);

        _mm_storeu_si128((__m128i*)&ColumnSumBuffer[0], ColumnSums[0]);
        _mm_storeu_si128((__m128i*)&ColumnSumBuffer[4], ColumnSums[1]);
        ColumnSumBuffer += 8;

        B += 8;
        CountN -= 8;
    }

    //
    // Process the remaining columns of matrix B.
    //

    if (CountN > 0) {

        const uint8_t* b = B;
        size_t k = CountK;
        __m128i ColumnSums[2];
        uint8_t PaddedMatrixBData[16];

        //
        // Pre-fill the padding with BitFlipVector so that the unused columns
        // become zero after the bit flip. Row 0 goes to bytes [0, 8) and row 1
        // to bytes [8, 16).
        //

        _mm_storeu_si128((__m128i*)PaddedMatrixBData, BitFlipVector);

        ColumnSums[0] = _mm_setzero_si128();
        ColumnSums[1] = _mm_setzero_si128();

        //
        // Interleave rows of matrix B using an intermediate zero padded stack
        // buffer and write to the packed buffer.
        //

        while (k >= 2) {

            const uint8_t* bcopy = b;
            uint8_t* padded = PaddedMatrixBData;
            uint8_t* padded_end = padded + CountN;

            do {
                padded[0] = bcopy[0];
                padded[8] = bcopy[ldb];
                padded++;
                bcopy++;
            } while (padded < padded_end);

            __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]);
            __m128i BytesRow1 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[8]);

            MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums);

            b += ldb * 2;
            D += 16;
            k -= 2;
        }

        if (k > 0) {

            const uint8_t* bcopy = b;
            uint8_t* padded = PaddedMatrixBData;
            uint8_t* padded_end = padded + CountN;

            do {
                padded[0] = bcopy[0];
                padded++;
                bcopy++;
            } while (padded < padded_end);

            __m128i BytesRow0 = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[0]);

            MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums);
        }

        //
        // Reduce the sums for the packed columns. All 8 sums are stored, even
        // though only the first CountN are meaningful. No zero point scaling
        // is applied here; MlasGemmU8X8Operation scales these sums afterwards
        // with MlasGemmU8X8ScaleSumBuffer.
        //

        ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast);
        ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast);

        _mm_storeu_si128((__m128i*)&ColumnSumBuffer[0], ColumnSums[0]);
        _mm_storeu_si128((__m128i*)&ColumnSumBuffer[4], ColumnSums[1]);
    }
}
|
|
|
|
MLAS_FORCEINLINE
void
MlasGemmU8X8MultiplyAccumulateRowSse(
    __m128i ABroadcast,
    const int16_t* B,
    __m128i Accumulators[2]
    )
/*++

Routine Description:

    This routine multiplies one packed row of matrix B (16 int16_t values,
    i.e. 8 interleaved column pairs) by a broadcast pair of matrix A values
    using _mm_madd_epi16, and adds the eight 32-bit results to the
    accumulators.

Arguments:

    ABroadcast - Supplies a pair of 16-bit A values replicated to every lane.

    B - Supplies the packed matrix B row. It must be 16-byte aligned because
        it is read with _mm_load_si128.

    Accumulators - Supplies the two accumulator vectors to update.

Return Value:

    None.

--*/
{
    const __m128i Product0 = _mm_madd_epi16(_mm_load_si128((__m128i*)&B[0]), ABroadcast);
    const __m128i Product1 = _mm_madd_epi16(_mm_load_si128((__m128i*)&B[8]), ABroadcast);

    Accumulators[0] = _mm_add_epi32(Accumulators[0], Product0);
    Accumulators[1] = _mm_add_epi32(Accumulators[1], Product1);
}
|
|
|
|
void
|
|
MlasGemmU8X8KernelSse(
|
|
const int16_t* A,
|
|
const int16_t* B,
|
|
int32_t* C,
|
|
size_t PackedCountK,
|
|
size_t CountN,
|
|
const int32_t* RowSumBuffer,
|
|
const int32_t* ColumnSumBuffer,
|
|
int32_t DepthValue,
|
|
bool ZeroMode
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine is an inner kernel to compute matrix multiplication for a
|
|
single row.
|
|
|
|
Arguments:
|
|
|
|
A - Supplies the address of matrix A. The matrix data has been packed
|
|
using MlasGemmU8X8CopyPackASse.
|
|
|
|
B - Supplies the address of matrix B. The matrix data has been packed
|
|
using MlasGemmU8X8CopyPackBSse.
|
|
|
|
C - Supplies the address of matrix C.
|
|
|
|
PackedCountK - Supplies the number of packed columns from matrix A and the
|
|
number of packed rows from matrix B to iterate over.
|
|
|
|
CountN - Supplies the number of columns from matrix B and matrix C to iterate
|
|
over.
|
|
|
|
RowSumBuffer - Supplies the sum of each row from matrix A multiplied by the
|
|
zero point offset of matrix B. These values are accumulated into every
|
|
row of matrix C.
|
|
|
|
ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied
|
|
by the zero point offset of matrix A. These values are accumulated into
|
|
every column of matrix C.
|
|
|
|
DepthValue - Supplies the value CountK multiplied by the zero point offset
|
|
of matrixA multplied by the zero point offset of matrix B. This value is
|
|
accumulated into every element of matrix C.
|
|
|
|
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
|
else false if the output matrix is accumulated into.
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
while (CountN > 0) {
|
|
|
|
__m128i Accumulators[2];
|
|
|
|
//
|
|
// Initialize the accumulators with the sum of the global depth value
|
|
// constant, the column sums, and the row sums.
|
|
//
|
|
|
|
Accumulators[0] = _mm_set1_epi32(DepthValue);
|
|
Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_set1_epi32(RowSumBuffer[0]));
|
|
Accumulators[1] = Accumulators[0];
|
|
Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&ColumnSumBuffer[0]));
|
|
Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*)&ColumnSumBuffer[4]));
|
|
ColumnSumBuffer += 8;
|
|
|
|
//
|
|
// Broadcast each pair of 16-bit values from the matrix A and multiply
|
|
// with the pair of 16-bit values from matrix B, and add the 32-bit
|
|
// intermediate into the accumulator registers.
|
|
//
|
|
|
|
const int16_t* a = A;
|
|
size_t k = PackedCountK;
|
|
|
|
while (k >= 4) {
|
|
|
|
__m128i AElements = _mm_loadu_si128((__m128i*)a);
|
|
__m128i ABroadcast;
|
|
|
|
ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0));
|
|
MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[0], Accumulators);
|
|
|
|
ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1));
|
|
MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[16], Accumulators);
|
|
|
|
ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(2, 2, 2, 2));
|
|
MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[32], Accumulators);
|
|
|
|
ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(3, 3, 3, 3));
|
|
MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[48], Accumulators);
|
|
|
|
a += 4 * 2;
|
|
B += 4 * 16;
|
|
k -= 4;
|
|
}
|
|
|
|
while (k > 0) {
|
|
|
|
__m128i ABroadcast = _mm_set1_epi32(*((int32_t*)a));
|
|
MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[0], Accumulators);
|
|
|
|
a += 2;
|
|
B += 16;
|
|
k -= 1;
|
|
}
|
|
|
|
//
|
|
// Output the accumulator block after optionally accumulating the values
|
|
// from matrix C.
|
|
//
|
|
|
|
if (CountN >= 8) {
|
|
|
|
if (!ZeroMode) {
|
|
Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0]));
|
|
Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*)&C[4]));
|
|
}
|
|
|
|
_mm_storeu_si128((__m128i*)&C[0], Accumulators[0]);
|
|
_mm_storeu_si128((__m128i*)&C[4], Accumulators[1]);
|
|
|
|
C += 8;
|
|
CountN -= 8;
|
|
|
|
} else {
|
|
|
|
//
|
|
// Output the remaining partial output block.
|
|
//
|
|
|
|
if ((CountN & 4) != 0) {
|
|
|
|
if (!ZeroMode) {
|
|
Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0]));
|
|
}
|
|
|
|
_mm_storeu_si128((__m128i*)&C[0], Accumulators[0]);
|
|
C += 4;
|
|
|
|
Accumulators[0] = Accumulators[1];
|
|
}
|
|
|
|
if ((CountN & 2) != 0) {
|
|
|
|
if (!ZeroMode) {
|
|
Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*)&C[0]));
|
|
}
|
|
|
|
_mm_storel_epi64((__m128i*)&C[0], Accumulators[0]);
|
|
C += 2;
|
|
|
|
Accumulators[0] = _mm_shuffle_epi32(Accumulators[0], _MM_SHUFFLE(3, 2, 3, 2));
|
|
}
|
|
|
|
if ((CountN & 1) != 0) {
|
|
|
|
int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulators[0]);
|
|
|
|
if (!ZeroMode) {
|
|
AccumulatorValue += C[0];
|
|
}
|
|
|
|
C[0] = AccumulatorValue;
|
|
}
|
|
|
|
CountN = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
struct MLAS_GEMM_U8X8_KERNEL_SSE
|
|
{
|
|
typedef int16_t PackedAType;
|
|
typedef int16_t PackedBType;
|
|
typedef int8_t OffsetBType;
|
|
|
|
static constexpr size_t PackedK = 2;
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES Strides{12, 128, 128};
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
bool
|
|
TryGemvKernel(
|
|
const uint8_t* A,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
int32_t* C,
|
|
size_t CountK,
|
|
size_t CountN,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(A);
|
|
MLAS_UNREFERENCED_PARAMETER(B);
|
|
MLAS_UNREFERENCED_PARAMETER(ldb);
|
|
MLAS_UNREFERENCED_PARAMETER(C);
|
|
MLAS_UNREFERENCED_PARAMETER(CountK);
|
|
MLAS_UNREFERENCED_PARAMETER(CountN);
|
|
MLAS_UNREFERENCED_PARAMETER(BIsSigned);
|
|
|
|
return false;
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackA(
|
|
PackedAType* D,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
size_t CountM,
|
|
size_t CountK,
|
|
int32_t* RowSumBuffer
|
|
)
|
|
{
|
|
MlasGemmU8X8CopyPackASse(D, A, lda, CountM, CountK, RowSumBuffer);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackB(
|
|
PackedBType* D,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
size_t CountN,
|
|
size_t CountK,
|
|
int32_t* ColumnSumBuffer,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MlasGemmU8X8CopyPackBSse(D, B, ldb, CountN, CountK, ColumnSumBuffer,
|
|
BIsSigned);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
size_t
|
|
GemmKernel(
|
|
const PackedAType* A,
|
|
const PackedBType* B,
|
|
int32_t* C,
|
|
size_t PackedCountK,
|
|
size_t CountM,
|
|
size_t CountN,
|
|
size_t ldc,
|
|
const int32_t* RowSumBuffer,
|
|
const int32_t* ColumnSumBuffer,
|
|
int32_t DepthValue,
|
|
bool ZeroMode
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(CountM);
|
|
MLAS_UNREFERENCED_PARAMETER(ldc);
|
|
|
|
MlasGemmU8X8KernelSse(A, B, C, PackedCountK, CountN, RowSumBuffer,
|
|
ColumnSumBuffer, DepthValue, ZeroMode);
|
|
|
|
return 1;
|
|
}
|
|
};
|
|
|
|
//
// Namespace-scope definitions of the constexpr static data members. Before
// C++17 these are required whenever the members are odr-used (for example,
// when KernelType::Strides is bound to a reference while being copied).
//

constexpr size_t MLAS_GEMM_U8X8_KERNEL_SSE::PackedK;
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_SSE::Strides;

//
// Explicitly instantiate the QGEMM driver for the SSE2 kernel. No packed-B
// instantiation exists because this kernel type defines no PackedStrides.
//

template
void
MLASCALL
MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    );
|
|
|
|
#endif
|
|
|
|
#ifdef MLAS_TARGET_AMD64
|
|
|
|
//
|
|
// Stores a vector to transpose a 4x4 byte vector using vpshufb.
|
|
//
|
|
|
|
//
// vpshufb control vector: output byte 4*i + j selects input byte 4*j + i,
// so a row-major 4x4 byte matrix is emitted in column-major order. Not
// referenced in this section; presumably consumed by the AVX2 assembly
// routines.
//
MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint8_t MlasTranspose4x4BytesAvx[16], 16) =
    { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
|
|
|
|
//
|
|
// Define the prototypes of the AVX2/AVX512 routines written in assembly.
|
|
//
|
|
|
|
extern "C" {

    //
    // Packs CountM rows by CountK columns of matrix A into D (bytes) for the
    // U8S8 kernel and writes one sum per row to RowSumBuffer. Argument meanings
    // presumably match MlasGemmU8X8CopyPackASse; the packed layout is defined
    // by the assembly implementation.
    //

    void
    MLASCALL
    MlasGemmU8S8CopyPackAAvx2(
        uint8_t* D,
        const uint8_t* A,
        size_t lda,
        size_t CountM,
        size_t CountK,
        int32_t* RowSumBuffer
        );

    //
    // Packs CountK rows by CountN columns of matrix B into D (bytes) for the
    // U8S8 kernel and writes one sum per column to ColumnSumBuffer. BIsSigned
    // indicates signed source data (presumably as in MlasGemmU8X8CopyPackBSse).
    //

    void
    MLASCALL
    MlasGemmU8S8CopyPackBAvx2(
        uint8_t* D,
        const uint8_t* B,
        size_t ldb,
        size_t CountN,
        size_t CountK,
        int32_t* ColumnSumBuffer,
        bool BIsSigned
        );

    //
    // Packs matrix A into 16-bit values (int16_t) for the U8U8 kernel and
    // writes one sum per row to RowSumBuffer.
    //

    void
    MLASCALL
    MlasGemmU8U8CopyPackAAvx2(
        int16_t* D,
        const uint8_t* A,
        size_t lda,
        size_t CountM,
        size_t CountK,
        int32_t* RowSumBuffer
        );

    //
    // Packs matrix B for the U8U8 kernel and writes one sum per column to
    // ColumnSumBuffer. There is no BIsSigned argument; the U8U8 CopyPackB
    // wrapper discards that flag.
    //

    void
    MLASCALL
    MlasGemmU8U8CopyPackBAvx2(
        uint8_t* D,
        const uint8_t* B,
        size_t ldb,
        size_t CountN,
        size_t CountK,
        int32_t* ColumnSumBuffer
        );
}
|
|
|
|
struct MLAS_GEMM_U8S8_KERNEL_AVX2
|
|
{
|
|
typedef uint8_t PackedAType;
|
|
typedef uint8_t PackedBType;
|
|
typedef int8_t OffsetBType;
|
|
|
|
static constexpr size_t PackedK = 4;
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 256, 128};
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{48, 256, 384};
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
bool
|
|
TryGemvKernel(
|
|
const uint8_t* A,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
int32_t* C,
|
|
size_t CountK,
|
|
size_t CountN,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
if (BIsSigned) {
|
|
MlasPlatform.GemvU8S8Kernel(A, B, C, CountK, CountN, ldb);
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackA(
|
|
PackedAType* D,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
size_t CountM,
|
|
size_t CountK,
|
|
int32_t* RowSumBuffer
|
|
)
|
|
{
|
|
MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackB(
|
|
PackedBType* D,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
size_t CountN,
|
|
size_t CountK,
|
|
int32_t* ColumnSumBuffer,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MlasGemmU8S8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer,
|
|
BIsSigned);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
size_t
|
|
GemmKernel(
|
|
const PackedAType* A,
|
|
const PackedBType* B,
|
|
int32_t* C,
|
|
size_t PackedCountK,
|
|
size_t CountM,
|
|
size_t CountN,
|
|
size_t ldc,
|
|
const int32_t* RowSumBuffer,
|
|
const int32_t* ColumnSumBuffer,
|
|
int32_t DepthValue,
|
|
bool ZeroMode
|
|
)
|
|
{
|
|
return MlasPlatform.GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN,
|
|
ldc, RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode);
|
|
}
|
|
};
|
|
|
|
constexpr size_t MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK;
|
|
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::Strides;
|
|
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides;
|
|
|
|
template
|
|
void
|
|
MlasGemmU8X8Operation<MLAS_GEMM_U8S8_KERNEL_AVX2>(
|
|
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
|
|
);
|
|
|
|
template
|
|
void
|
|
MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>(
|
|
const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
|
|
);
|
|
|
|
struct MLAS_GEMM_U8U8_KERNEL_AVX2
|
|
{
|
|
typedef int16_t PackedAType;
|
|
typedef uint8_t PackedBType;
|
|
typedef uint8_t OffsetBType;
|
|
|
|
static constexpr size_t PackedK = 2;
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 256, 128};
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{48, 256, 384};
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
bool
|
|
TryGemvKernel(
|
|
const uint8_t* A,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
int32_t* C,
|
|
size_t CountK,
|
|
size_t CountN,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(A);
|
|
MLAS_UNREFERENCED_PARAMETER(B);
|
|
MLAS_UNREFERENCED_PARAMETER(ldb);
|
|
MLAS_UNREFERENCED_PARAMETER(C);
|
|
MLAS_UNREFERENCED_PARAMETER(CountK);
|
|
MLAS_UNREFERENCED_PARAMETER(CountN);
|
|
MLAS_UNREFERENCED_PARAMETER(BIsSigned);
|
|
|
|
return false;
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackA(
|
|
PackedAType* D,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
size_t CountM,
|
|
size_t CountK,
|
|
int32_t* RowSumBuffer
|
|
)
|
|
{
|
|
MlasGemmU8U8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackB(
|
|
PackedBType* D,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
size_t CountN,
|
|
size_t CountK,
|
|
int32_t* ColumnSumBuffer,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(BIsSigned);
|
|
|
|
MlasGemmU8U8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
size_t
|
|
GemmKernel(
|
|
const PackedAType* A,
|
|
const PackedBType* B,
|
|
int32_t* C,
|
|
size_t PackedCountK,
|
|
size_t CountM,
|
|
size_t CountN,
|
|
size_t ldc,
|
|
const int32_t* RowSumBuffer,
|
|
const int32_t* ColumnSumBuffer,
|
|
int32_t DepthValue,
|
|
bool ZeroMode
|
|
)
|
|
{
|
|
return MlasPlatform.GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN,
|
|
ldc, RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode);
|
|
}
|
|
};
|
|
|
|
//
// Out-of-line definitions of the static constexpr data members declared in
// MLAS_GEMM_U8U8_KERNEL_AVX2. Before C++17 made static constexpr data members
// implicitly inline, a namespace-scope definition like these is required
// whenever such a member is odr-used (for example, bound to a reference).
//

constexpr size_t MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK;
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::Strides;
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides;
|
|
|
|
//
// Explicitly instantiate the unpacked and packed QGEMM operation templates for
// the AVX2 U8U8 kernel so that the instantiations are emitted in this
// translation unit. The packed instantiation's address is compared against
// the platform dispatch pointers in MlasGemmPackBSize and MlasGemmPackB.
//
// NOTE(review): only the unpacked instantiation spells out MLASCALL; the packed
// one here and both U8S8 instantiations above do not. Confirm which form
// matches the primary template declaration.
//

template
void
MLASCALL
MlasGemmU8X8Operation<MLAS_GEMM_U8U8_KERNEL_AVX2>(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    );

template
void
MlasGemmU8X8PackedOperation<MLAS_GEMM_U8U8_KERNEL_AVX2>(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    );
|
|
|
|
#endif
|
|
|
|
#ifdef MLAS_NEON_INTRINSICS
|
|
|
|
//
|
|
// Define the prototype of the NEON kernel routine written in assembly.
|
|
//
|
|
|
|
//
// QGEMM inner kernel for NEON, implemented outside this file. It consumes
// matrix A as packed by MlasGemmU8X8CopyPackANeon and matrix B as packed by
// MlasGemmU8X8CopyPackBNeon. It is called through
// MLAS_GEMM_U8X8_KERNEL_NEON::GemmKernel, which forwards its return value.
//
// NOTE(review): by analogy with the other MLAS GEMM kernels, the return value
// is presumably the number of rows of C that were produced. ZeroMode
// presumably selects overwriting C rather than accumulating into it. Confirm
// both against the assembly source.
//

extern "C"
size_t
MLASCALL
MlasGemmU8X8KernelNeon(
    const uint8_t* A,
    const uint8_t* B,
    int32_t* C,
    size_t PackedCountK,
    size_t CountM,
    size_t CountN,
    size_t ldc,
    const int32_t* RowSumVector,
    const int32_t* ColumnSumVector,
    int32_t DepthValue,
    bool ZeroMode
    );
|
|
|
|
void
MLASCALL
MlasGemmU8X8CopyPackANeon(
    uint8_t* D,
    const uint8_t* A,
    size_t lda,
    size_t CountM,
    size_t CountK,
    int32_t* RowSumBuffer
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer. Rows are interleaved in groups of four, then two, then one,
    and the sum of each copied row is computed along the way.

Arguments:

    D - Supplies the address of the destination packed buffer. Partial groups
        of columns are zero padded. Note that the single-row tail always
        stores a full 16 byte vector, even when fewer bytes remain, so the
        buffer presumably needs slack beyond the four-column-aligned size
        (confirm against the caller's allocation).

    A - Supplies the address of the source matrix.

    lda - Supplies the number of elements per row of the source matrix.

    CountM - Supplies the number of rows of the source matrix to copy.

    CountK - Supplies the number of columns of the source matrix to copy.

    RowSumBuffer - Supplies the address of the buffer to receive the sums of
        the elements along each of the rows (one entry per row). The sums are
        accumulated as unsigned 32-bit values and stored reinterpreted as
        int32_t.

Return Value:

    None.

--*/
{
    //
    // Zero-filled staging buffer used to pad a partial group of columns.
    //

    uint8_t PaddedMatrixAData[16];

    //
    // Process four rows of matrix A in a loop.
    //
    // The buffer is packed as a series of 16 byte vectors where four rows are
    // interleaved with the following pattern:
    //
    //      [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ]
    //      [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ]
    //
    // This pattern is repeated (CountK / 4) times.
    //
    // If CountK is not aligned to a multiple of four, then the vector is padded
    // with zeroes.
    //

    while (CountM >= 4) {

        const uint8_t* a0 = A;
        const uint8_t* a1 = a0 + lda;
        const uint8_t* a2 = a1 + lda;
        const uint8_t* a3 = a2 + lda;

        size_t k = CountK;

        //
        // Lane i of RowSums accumulates the sum of row i of this group.
        //

        uint32x4_t RowSums = vmovq_n_u32(0);

        while (k >= 16) {

            uint32x4_t v0 = vld1q_u32(reinterpret_cast<const uint32_t*>(a0));
            a0 += 16;
            uint32x4_t v1 = vld1q_u32(reinterpret_cast<const uint32_t*>(a1));
            a1 += 16;
            uint32x4_t v2 = vld1q_u32(reinterpret_cast<const uint32_t*>(a2));
            a2 += 16;
            uint32x4_t v3 = vld1q_u32(reinterpret_cast<const uint32_t*>(a3));
            a3 += 16;

            //
            // Transpose the 4x4 block of 32-bit words (each word holds four
            // consecutive columns of one row) so that output vector j holds
            // word j of rows 0..3, matching the pattern above. Both paths
            // compute the same transpose with two rounds of zips.
            //

#if defined(MLAS_NEON32_INTRINSICS)
            uint32x4x2_t z0 = vzipq_u32(v0, v2);
            uint32x4x2_t z1 = vzipq_u32(v1, v3);

            v0 = z0.val[0];
            v1 = z0.val[1];
            v2 = z1.val[0];
            v3 = z1.val[1];

            uint32x4x2_t z2 = vzipq_u32(v0, v2);
            uint32x4x2_t z3 = vzipq_u32(v1, v3);

            v0 = z2.val[0];
            v1 = z2.val[1];
            v2 = z3.val[0];
            v3 = z3.val[1];
#else
            uint32x4_t z0 = vzip1q_u32(v0, v2);
            uint32x4_t z1 = vzip2q_u32(v0, v2);
            uint32x4_t z2 = vzip1q_u32(v1, v3);
            uint32x4_t z3 = vzip2q_u32(v1, v3);

            v0 = vzip1q_u32(z0, z2);
            v1 = vzip2q_u32(z0, z2);
            v2 = vzip1q_u32(z1, z3);
            v3 = vzip2q_u32(z1, z3);
#endif

            vst1q_u8(&D[0], vreinterpretq_u8_u32(v0));
            vst1q_u8(&D[16], vreinterpretq_u8_u32(v1));
            vst1q_u8(&D[32], vreinterpretq_u8_u32(v2));
            vst1q_u8(&D[48], vreinterpretq_u8_u32(v3));

            //
            // vpaddlq_u8 adds adjacent byte pairs and vpadalq_u16 folds
            // adjacent halfword pairs into the 32-bit lanes, so each lane
            // gains the four bytes belonging to its row.
            //

            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v0)));
            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v1)));
            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v2)));
            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v3)));

            D += 64;
            k -= 16;
        }

        //
        // Gather four columns from each row per iteration. Every lane is
        // overwritten each time, so the vector need not be reset.
        //

        uint32x4_t GatherVector = vmovq_n_u32(0);

        while (k >= 4) {

            GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
            a0 += 4;
            GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
            a1 += 4;
            GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a2), GatherVector, 2);
            a2 += 4;
            GatherVector = vld1q_lane_u32(reinterpret_cast<const uint32_t*>(a3), GatherVector, 3);
            a3 += 4;

            uint8x16_t PackedVector = vreinterpretq_u8_u32(GatherVector);
            vst1q_u8(D, PackedVector);

            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));

            D += 16;
            k -= 4;
        }

        if (k > 0) {

            //
            // Copy the remaining bytes to the zero padded stack buffer.
            //

            uint8_t* d = PaddedMatrixAData;

            vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0));

            while (k > 0) {

                d[0] = *a0++;
                d[4] = *a1++;
                d[8] = *a2++;
                d[12] = *a3++;

                d += 1;
                k -= 1;
            }

            uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData);
            vst1q_u8(D, PackedVector);

            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector));

            D += 16;
        }

        vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums));
        RowSumBuffer += 4;

        A = A + lda * 4;
        CountM -= 4;
    }

    //
    // Process two rows of matrix A.
    //
    // The buffer is packed as a series of 8 byte vectors where two rows are
    // interleaved with the following pattern:
    //
    //      [ A0 A1 A2 A3 B0 B1 B2 B3 ]
    //      [ A4 A5 A6 A7 B4 B5 B6 B7 ]
    //
    // This pattern is repeated (CountK / 4) times.
    //
    // If CountK is not aligned to a multiple of four, then the vector is padded
    // with zeroes.
    //

    if ((CountM & 2) != 0) {

        const uint8_t* a0 = A;
        const uint8_t* a1 = a0 + lda;

        size_t k = CountK;

        //
        // Lane 0 accumulates row a0 and lane 1 accumulates row a1.
        //

        uint32x2_t RowSums = vmov_n_u32(0);
        uint32x2_t GatherVector = vmov_n_u32(0);

        while (k >= 4) {

            GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a0), GatherVector, 0);
            a0 += 4;
            GatherVector = vld1_lane_u32(reinterpret_cast<const uint32_t*>(a1), GatherVector, 1);
            a1 += 4;

            uint8x8_t PackedVector = vreinterpret_u8_u32(GatherVector);
            vst1_u8(D, PackedVector);

            RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));

            D += 8;
            k -= 4;
        }

        if (k > 0) {

            //
            // Copy the remaining bytes to the zero padded stack buffer.
            //

            uint8_t* d = PaddedMatrixAData;

            vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0));

            while (k > 0) {

                d[0] = *a0++;
                d[4] = *a1++;

                d += 1;
                k -= 1;
            }

            uint8x8_t PackedVector = vld1_u8(PaddedMatrixAData);
            vst1_u8(D, PackedVector);

            RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector));

            D += 8;
        }

        vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums));
        RowSumBuffer += 2;

        A = A + lda * 2;
        CountM -= 2;
    }

    //
    // Process one row of matrix A.
    //
    // The buffer is packed as a series of 4 byte vectors with the following
    // pattern:
    //
    //      [ A0 A1 A2 A3 ]
    //      [ A4 A5 A6 A7 ]
    //
    // This pattern is repeated (CountK / 4) times.
    //
    // If CountK is not aligned to a multiple of four, then the vector is padded
    // with zeroes.
    //

    if ((CountM & 1) != 0) {

        const uint8_t* a = A;
        size_t k = CountK;

        //
        // All four lanes belong to the same row; they are reduced to a single
        // sum after the copy.
        //

        uint32x4_t RowSums = vmovq_n_u32(0);

        while (k >= 16) {

            uint8x16_t v = vld1q_u8(a);
            a += 16;

            vst1q_u8(D, v);

            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v));

            D += 16;
            k -= 16;
        }

        if (k > 0) {

            //
            // Copy the remaining bytes to the zero padded stack buffer.
            //
            // N.B. A full 16 byte vector is stored to D even when fewer than
            // 16 bytes remain.
            //

            vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0));

            for (size_t kk = 0; kk < k; kk++) {
                PaddedMatrixAData[kk] = a[kk];
            }

            uint8x16_t v = vld1q_u8(PaddedMatrixAData);
            vst1q_u8(D, v);

            RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v));
        }

        //
        // Horizontally reduce the four partial sums into the single row sum.
        //

#if defined(MLAS_NEON32_INTRINSICS)
        uint32x2_t RowSumsLow = vpadd_u32(vget_high_u32(RowSums), vget_low_u32(RowSums));
        RowSumsLow = vpadd_u32(RowSumsLow, RowSumsLow);
        vst1_lane_u32(reinterpret_cast<uint32_t*>(RowSumBuffer), RowSumsLow, 0);
#elif defined(_M_ARM64)
        // N.B. The workaround of defining a local vaddvq_u32 doesn't work here
        // as VS2019 added new intrinsics to make the operation work. Also, not
        // all build environments using VS2019 have the up-to-date arm64_neon.h,
        // so fallback to pairwise addition.
        RowSums = vpaddq_u32(RowSums, RowSums);
        RowSums = vpaddq_u32(RowSums, RowSums);
        vst1q_lane_u32(reinterpret_cast<uint32_t*>(RowSumBuffer), RowSums, 0);
#else
        *RowSumBuffer = int32_t(vaddvq_u32(RowSums));
#endif
    }
}
|
|
|
|
//
// Copies one row of eight bytes of matrix B to the packed buffer and adds the
// bytes into the running column sums.
//
// D - Receives the eight packed bytes.
//
// B - Supplies the eight source bytes (read with a single 8 byte load).
//
// BitFlipVector - XOR mask applied to every byte before it is stored and
//     summed. MlasGemmU8X8CopyPackBNeon passes 0x80 for signed B, which maps
//     int8 values to their uint8 offset-by-128 form, and 0 for unsigned B.
//
// ColumnSums - Two accumulators updated in place: ColumnSums[0] receives
//     columns 0-3 and ColumnSums[1] receives columns 4-7, each widened to
//     32 bits.
//

MLAS_FORCEINLINE
void
MlasGemmU8X8CopyPackBProcessNeon(
    uint8_t* D,
    const uint8_t* B,
    uint8x8_t BitFlipVector,
    uint32x4_t ColumnSums[2]
    )
{
    uint8x8_t BytesRow = veor_u8(vld1_u8(B), BitFlipVector);
    vst1_u8(D, BytesRow);

    //
    // Widen bytes to 16 bits and then each half to 32 bits for accumulation.
    //

    uint16x8_t WordsRow = vmovl_u8(BytesRow);
    ColumnSums[0] = vaddq_u32(ColumnSums[0], vmovl_u16(vget_low_u16(WordsRow)));
#if defined(MLAS_NEON32_INTRINSICS)
    ColumnSums[1] = vaddq_u32(ColumnSums[1], vmovl_u16(vget_high_u16(WordsRow)));
#else
    ColumnSums[1] = vaddq_u32(ColumnSums[1], vmovl_high_u16(WordsRow));
#endif
}
|
|
|
|
void
MLASCALL
MlasGemmU8X8CopyPackBNeon(
    uint8_t* D,
    const uint8_t* B,
    size_t ldb,
    size_t CountN,
    size_t CountK,
    int32_t* ColumnSumBuffer,
    bool BIsSigned
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer, in panels of eight columns, and computes the sum of each
    copied column.

Arguments:

    D - Supplies the address of the destination packed buffer.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

    CountN - Supplies the number of columns of the source matrix to copy.

    CountK - Supplies the number of rows of the source matrix to copy.

    ColumnSumBuffer - Supplies the address of the buffer to receive the sums of
        the elements along each of the columns. Sums are always written in
        groups of eight, so the buffer must hold CountN rounded up to a
        multiple of eight entries.

    BIsSigned - Supplies true if the source matrix is signed data, else false
        if the source matrix is unsigned data.

Return Value:

    None.

--*/
{
    const uint8x8_t BitFlipVector = vdup_n_u8(BIsSigned ? 0x80 : 0);
    const uint8x8_t ZeroVector = vmov_n_u8(0);

    //
    // Each panel is padded along K to a multiple of four rows.
    //

    const size_t AlignedCountK = (CountK + 3) & ~3;

    //
    // Process 8 columns of matrix B in a loop.
    //
    // Copy columns from matrix B to the packed buffer. Signed buffers are
    // converted to unsigned buffers in order to share a common kernel.
    //
    // If CountK is not aligned to a multiple of four, then the packed buffer
    // is padded with zero vectors.
    //
    // If CountN is not aligned to a multiple of eight, then the extra columns
    // are padded with zeroes before the sign conversion (so they are stored
    // as 0x80 when BIsSigned is true).
    //

    while (CountN >= 8) {

        const uint8_t* b = B;
        uint32x4_t ColumnSums[2];

        ColumnSums[0] = vmovq_n_u32(0);
        ColumnSums[1] = vmovq_n_u32(0);

        for (size_t k = CountK; k > 0; k--) {

            MlasGemmU8X8CopyPackBProcessNeon(D, b, BitFlipVector, ColumnSums);

            b += ldb;
            D += 8;
        }

        for (size_t k = CountK; k < AlignedCountK; k++) {
            vst1_u8(D, ZeroVector);
            D += 8;
        }

        vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0]));
        vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1]));
        ColumnSumBuffer += 8;

        B += 8;
        CountN -= 8;
    }

    //
    // Process the remaining columns of matrix B.
    //
    // Each row is staged through an eight byte buffer whose trailing entries
    // stay zero, so the shared per-row helper can still load eight bytes.
    //

    if (CountN > 0) {

        const uint8_t* b = B;
        uint8_t PaddedMatrixBData[8];
        uint32x4_t ColumnSums[2];

        vst1_u8(PaddedMatrixBData, ZeroVector);

        ColumnSums[0] = vmovq_n_u32(0);
        ColumnSums[1] = vmovq_n_u32(0);

        for (size_t k = CountK; k > 0; k--) {

            for (size_t n = 0; n < CountN; n++) {
                PaddedMatrixBData[n] = b[n];
            }

            MlasGemmU8X8CopyPackBProcessNeon(D, PaddedMatrixBData, BitFlipVector, ColumnSums);

            b += ldb;
            D += 8;
        }

        for (size_t k = CountK; k < AlignedCountK; k++) {
            vst1_u8(D, ZeroVector);
            D += 8;
        }

        //
        // All eight sums are stored, including those of the padding columns.
        //

        vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0]));
        vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1]));
    }
}
|
|
|
|
struct MLAS_GEMM_U8X8_KERNEL_NEON
|
|
{
|
|
typedef uint8_t PackedAType;
|
|
typedef uint8_t PackedBType;
|
|
typedef uint8_t OffsetBType;
|
|
|
|
static constexpr size_t PackedK = 4;
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES Strides{24, 128, 256};
|
|
static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{24, 256, 128};
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
bool
|
|
TryGemvKernel(
|
|
const uint8_t* A,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
int32_t* C,
|
|
size_t CountK,
|
|
size_t CountN,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(A);
|
|
MLAS_UNREFERENCED_PARAMETER(B);
|
|
MLAS_UNREFERENCED_PARAMETER(ldb);
|
|
MLAS_UNREFERENCED_PARAMETER(C);
|
|
MLAS_UNREFERENCED_PARAMETER(CountK);
|
|
MLAS_UNREFERENCED_PARAMETER(CountN);
|
|
MLAS_UNREFERENCED_PARAMETER(BIsSigned);
|
|
|
|
return false;
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackA(
|
|
PackedAType* D,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
size_t CountM,
|
|
size_t CountK,
|
|
int32_t* RowSumBuffer
|
|
)
|
|
{
|
|
MlasGemmU8X8CopyPackANeon(D, A, lda, CountM, CountK, RowSumBuffer);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
void
|
|
CopyPackB(
|
|
PackedBType* D,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
size_t CountN,
|
|
size_t CountK,
|
|
int32_t* ColumnSumBuffer,
|
|
bool BIsSigned
|
|
)
|
|
{
|
|
MlasGemmU8X8CopyPackBNeon(D, B, ldb, CountN, CountK, ColumnSumBuffer,
|
|
BIsSigned);
|
|
}
|
|
|
|
MLAS_FORCEINLINE
|
|
static
|
|
size_t
|
|
GemmKernel(
|
|
const PackedAType* A,
|
|
const PackedBType* B,
|
|
int32_t* C,
|
|
size_t PackedCountK,
|
|
size_t CountM,
|
|
size_t CountN,
|
|
size_t ldc,
|
|
const int32_t* RowSumBuffer,
|
|
const int32_t* ColumnSumBuffer,
|
|
int32_t DepthValue,
|
|
bool ZeroMode
|
|
)
|
|
{
|
|
return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc,
|
|
RowSumBuffer, ColumnSumBuffer, DepthValue, ZeroMode);
|
|
}
|
|
};
|
|
|
|
//
// Out-of-line definitions of the static constexpr data members declared in
// MLAS_GEMM_U8X8_KERNEL_NEON. Before C++17 made static constexpr data members
// implicitly inline, a namespace-scope definition like these is required
// whenever such a member is odr-used (for example, bound to a reference).
//

constexpr size_t MLAS_GEMM_U8X8_KERNEL_NEON::PackedK;
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::Strides;
constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides;
|
|
|
|
//
// Explicitly instantiate the packed QGEMM operation for the NEON kernel. The
// unpacked MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_NEON> is not listed
// here; it is instantiated implicitly by its call in MlasGemmU8X8Threaded.
//

template
void
MLASCALL
MlasGemmU8X8PackedOperation<MLAS_GEMM_U8X8_KERNEL_NEON>(
    const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock
    );
|
|
|
|
#endif
|
|
|
|
void
|
|
MlasGemmU8X8Threaded(
|
|
void* Context,
|
|
int32_t ThreadId
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine is invoked from a worker thread to execute a segment of a
|
|
QGEMM operation.
|
|
|
|
Arguments:
|
|
|
|
Context - Supplies the pointer to the context for the threaded operation.
|
|
|
|
ThreadId - Supplies the current index of the threaded operation.
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
|
|
|
|
memcpy(&WorkBlock, Context, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
|
|
|
|
const int32_t ThreadIdM = ThreadId / WorkBlock.ThreadCountN;
|
|
const int32_t ThreadIdN = ThreadId % WorkBlock.ThreadCountN;
|
|
|
|
//
|
|
// Partition the operation along the M dimension.
|
|
//
|
|
|
|
MlasPartitionWork(ThreadIdM, WorkBlock.ThreadCountM, WorkBlock.M,
|
|
&WorkBlock.RangeStartM, &WorkBlock.RangeCountM);
|
|
|
|
//
|
|
// Partition the operation along the N dimension.
|
|
//
|
|
|
|
const size_t BlockedN = (WorkBlock.N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
|
MLAS_QGEMM_STRIDEN_THREAD_ALIGN;
|
|
|
|
MlasPartitionWork(ThreadIdN, WorkBlock.ThreadCountN, BlockedN,
|
|
&WorkBlock.RangeStartN, &WorkBlock.RangeCountN);
|
|
|
|
WorkBlock.RangeStartN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN;
|
|
WorkBlock.RangeCountN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN;
|
|
|
|
WorkBlock.RangeCountN = std::min(WorkBlock.N - WorkBlock.RangeStartN,
|
|
WorkBlock.RangeCountN);
|
|
|
|
//
|
|
// Dispatch the partitioned operation.
|
|
//
|
|
|
|
#if defined(MLAS_TARGET_AMD64)
|
|
PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation;
|
|
|
|
if (WorkBlock.BIsSigned) {
|
|
GemmU8X8Operation = WorkBlock.BIsPacked ?
|
|
MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8S8Operation;
|
|
} else {
|
|
GemmU8X8Operation = WorkBlock.BIsPacked ?
|
|
MlasPlatform.GemmU8U8PackedOperation : MlasPlatform.GemmU8U8Operation;
|
|
}
|
|
|
|
GemmU8X8Operation(&WorkBlock);
|
|
#elif defined(MLAS_SSE2_INTRINSICS)
|
|
MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_SSE>(&WorkBlock);
|
|
#elif defined(MLAS_NEON_INTRINSICS)
|
|
if (WorkBlock.BIsPacked) {
|
|
MlasGemmU8X8PackedOperation<MLAS_GEMM_U8X8_KERNEL_NEON>(&WorkBlock);
|
|
} else {
|
|
MlasGemmU8X8Operation<MLAS_GEMM_U8X8_KERNEL_NEON>(&WorkBlock);
|
|
}
|
|
#else
|
|
#error Unsupported architecture.
|
|
#endif
|
|
}
|
|
|
|
void
|
|
MlasGemmU8X8Schedule(
|
|
MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine schedules the quantized integer matrix/matrix multiply
|
|
operation (QGEMM) across one or more threads.
|
|
|
|
Arguments:
|
|
|
|
WorkBlock - Supplies the structure containing the GEMM parameters.
|
|
|
|
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
|
base library threading support should be used.
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
const size_t M = WorkBlock->M;
|
|
const size_t N = WorkBlock->N;
|
|
const size_t K = WorkBlock->K;
|
|
|
|
//
|
|
// Compute the number of target threads given the complexity of the SGEMM
|
|
// operation. Small requests should run using the single threaded path.
|
|
//
|
|
|
|
const double Complexity = double(M) * double(N) * double(K);
|
|
|
|
int32_t TargetThreadCount;
|
|
|
|
if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
|
|
TargetThreadCount = int32_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
|
|
} else {
|
|
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
|
}
|
|
|
|
int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
|
|
|
|
if (TargetThreadCount >= MaximumThreadCount) {
|
|
TargetThreadCount = MaximumThreadCount;
|
|
}
|
|
|
|
//
|
|
// Segment the operation across multiple threads.
|
|
//
|
|
// N.B. Currently, the operation is segmented as a 1D partition, which
|
|
// works okay for operations involving skinny matrices.
|
|
//
|
|
|
|
if (N > M) {
|
|
|
|
const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
|
MLAS_QGEMM_STRIDEN_THREAD_ALIGN;
|
|
|
|
if (size_t(TargetThreadCount) > BlockedN) {
|
|
TargetThreadCount = int32_t(BlockedN);
|
|
}
|
|
|
|
WorkBlock->ThreadCountM = 1;
|
|
WorkBlock->ThreadCountN = TargetThreadCount;
|
|
|
|
} else {
|
|
|
|
if (size_t(TargetThreadCount) > M) {
|
|
TargetThreadCount = int32_t(M);
|
|
}
|
|
|
|
WorkBlock->ThreadCountM = TargetThreadCount;
|
|
WorkBlock->ThreadCountN = 1;
|
|
}
|
|
|
|
MlasExecuteThreaded(MlasGemmU8X8Threaded, WorkBlock, TargetThreadCount, ThreadPool);
|
|
}
|
|
|
|
void
|
|
MLASCALL
|
|
MlasGemm(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
uint8_t offa,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
uint8_t offb,
|
|
bool BIsSigned,
|
|
int32_t* C,
|
|
size_t ldc,
|
|
MLAS_THREADPOOL* ThreadPool,
|
|
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine implements the quantized integer matrix/matrix multiply
|
|
operation (QGEMM).
|
|
|
|
Arguments:
|
|
|
|
M - Supplies the number of rows of matrix A and matrix C.
|
|
|
|
N - Supplies the number of columns of matrix B and matrix C.
|
|
|
|
K - Supplies the number of columns of matrix A and the number of rows of
|
|
matrix B.
|
|
|
|
A - Supplies the address of matrix A.
|
|
|
|
lda - Supplies the first dimension of matrix A.
|
|
|
|
offa - Supplies the zero point offset of matrix A.
|
|
|
|
B - Supplies the address of matrix B.
|
|
|
|
ldb - Supplies the first dimension of matrix B.
|
|
|
|
offb - Supplies the zero point offset of matrix B.
|
|
|
|
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
|
B is unsigned data.
|
|
|
|
C - Supplies the address of matrix C.
|
|
|
|
ldc - Supplies the first dimension of matrix C.
|
|
|
|
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
|
base library threading support should be used.
|
|
|
|
OutputProcessor - Post Processor on C.
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
|
|
|
|
//
|
|
// Capture the GEMM parameters to the work block.
|
|
//
|
|
|
|
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
|
|
|
|
WorkBlock.M = M;
|
|
WorkBlock.N = N;
|
|
WorkBlock.K = K;
|
|
WorkBlock.A = A;
|
|
WorkBlock.lda = lda;
|
|
WorkBlock.B = B;
|
|
WorkBlock.ldb = ldb;
|
|
WorkBlock.C = C;
|
|
WorkBlock.ldc = ldc;
|
|
WorkBlock.OutputProcessor = OutputProcessor;
|
|
WorkBlock.offa = offa;
|
|
WorkBlock.offb = offb;
|
|
WorkBlock.BIsSigned = BIsSigned;
|
|
|
|
//
|
|
// Schedule the operation across a set of worker threads.
|
|
//
|
|
|
|
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
|
|
}
|
|
|
|
#endif // MLAS_SUPPORTS_GEMM_U8X8
|
|
|
|
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
|
|
|
void
|
|
MLASCALL
|
|
MlasGemm(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
const uint8_t* A,
|
|
size_t lda,
|
|
uint8_t offa,
|
|
const void* PackedB,
|
|
uint8_t offb,
|
|
bool BIsSigned,
|
|
int32_t* C,
|
|
size_t ldc,
|
|
MLAS_THREADPOOL* ThreadPool,
|
|
const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine implements the quantized integer matrix/matrix multiply
|
|
operation (QGEMM).
|
|
|
|
Arguments:
|
|
|
|
M - Supplies the number of rows of matrix A and matrix C.
|
|
|
|
N - Supplies the number of columns of matrix B and matrix C.
|
|
|
|
K - Supplies the number of columns of matrix A and the number of rows of
|
|
matrix B.
|
|
|
|
A - Supplies the address of matrix A.
|
|
|
|
lda - Supplies the first dimension of matrix A.
|
|
|
|
offa - Supplies the zero point offset of matrix A.
|
|
|
|
PackedB - Supplies the address of packed matrix B.
|
|
|
|
offb - Supplies the zero point offset of matrix B.
|
|
|
|
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
|
B is unsigned data.
|
|
|
|
C - Supplies the address of matrix C.
|
|
|
|
ldc - Supplies the first dimension of matrix C.
|
|
|
|
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
|
base library threading support should be used.
|
|
|
|
OutputProcessor - Post Processor on C
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock;
|
|
|
|
//
|
|
// Capture the GEMM parameters to the work block.
|
|
//
|
|
|
|
memset(&WorkBlock, 0, sizeof(MLAS_GEMM_U8X8_WORK_BLOCK));
|
|
|
|
WorkBlock.M = M;
|
|
WorkBlock.N = N;
|
|
WorkBlock.K = K;
|
|
WorkBlock.A = A;
|
|
WorkBlock.lda = lda;
|
|
WorkBlock.B = PackedB;
|
|
WorkBlock.C = (int32_t*)C;
|
|
WorkBlock.ldc = ldc;
|
|
WorkBlock.OutputProcessor = OutputProcessor,
|
|
WorkBlock.offa = offa;
|
|
WorkBlock.offb = offb;
|
|
WorkBlock.BIsPacked = true;
|
|
WorkBlock.BIsSigned = BIsSigned;
|
|
|
|
//
|
|
// Schedule the operation across a set of worker threads.
|
|
//
|
|
|
|
MlasGemmU8X8Schedule(&WorkBlock, ThreadPool);
|
|
}
|
|
|
|
size_t
|
|
MLASCALL
|
|
MlasGemmPackBSize(
|
|
size_t N,
|
|
size_t K,
|
|
bool BIsSigned
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine computes the number of bytes required to pack a matrix with
|
|
the supplied shape and type.
|
|
|
|
Arguments:
|
|
|
|
N - Supplies the number of columns of matrix B.
|
|
|
|
K - Supplies the the number of rows of matrix B.
|
|
|
|
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
|
B is unsigned data.
|
|
|
|
Return Value:
|
|
|
|
Returns the number of bytes required to pack the matrix.
|
|
|
|
--*/
|
|
{
|
|
//
|
|
// Retrieve the packing parameters based on the packed operation function.
|
|
//
|
|
|
|
size_t PackedK;
|
|
|
|
#if defined(MLAS_TARGET_AMD64)
|
|
PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation = BIsSigned ?
|
|
MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8U8PackedOperation;
|
|
|
|
if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>) {
|
|
PackedK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK;
|
|
} else if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation<MLAS_GEMM_U8U8_KERNEL_AVX2>) {
|
|
PackedK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK;
|
|
} else {
|
|
return 0;
|
|
}
|
|
#elif defined(MLAS_NEON_INTRINSICS)
|
|
MLAS_UNREFERENCED_PARAMETER(BIsSigned);
|
|
|
|
PackedK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedK;
|
|
#else
|
|
#error Unknown architecture.
|
|
#endif
|
|
|
|
//
|
|
// Compute the number of bytes required to hold the packed buffer.
|
|
//
|
|
|
|
const size_t AlignedN =
|
|
(N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1);
|
|
const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1);
|
|
|
|
const size_t BytesRequired =
|
|
(AlignedN * sizeof(int32_t)) + (AlignedN * AlignedK * sizeof(uint8_t));
|
|
const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
|
|
const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) &
|
|
~(BufferAlignment - 1);
|
|
|
|
return AlignedBytesRequired;
|
|
}
|
|
|
|
void
|
|
MLASCALL
|
|
MlasGemmPackB(
|
|
size_t N,
|
|
size_t K,
|
|
const uint8_t* B,
|
|
size_t ldb,
|
|
bool BIsSigned,
|
|
void* PackedB
|
|
)
|
|
/*++
|
|
|
|
Routine Description:
|
|
|
|
This routine packs the supplied matrix B to the supplied packed matrix B
|
|
buffer. The size of the packed buffer was obtained from MlasGemmPackBSize.
|
|
|
|
Arguments:
|
|
|
|
N - Supplies the number of columns of matrix B.
|
|
|
|
K - Supplies the the number of rows of matrix B.
|
|
|
|
B - Supplies the address of matrix B.
|
|
|
|
ldb - Supplies the first dimension of matrix B.
|
|
|
|
BIsSigned - Supplies true if matrix B is signed data, else false if matrix
|
|
B is unsigned data.
|
|
|
|
PackedB - Supplies the address of packed matrix B.
|
|
|
|
Return Value:
|
|
|
|
None.
|
|
|
|
--*/
|
|
{
|
|
//
|
|
// Retrieve the packing parameters based on the packed operation function.
|
|
//
|
|
|
|
size_t PackedK;
|
|
size_t StrideK;
|
|
|
|
#if defined(MLAS_TARGET_AMD64)
|
|
PMLAS_GEMM_U8X8_OPERATION GemmU8X8Operation = BIsSigned ?
|
|
MlasPlatform.GemmU8S8PackedOperation : MlasPlatform.GemmU8U8PackedOperation;
|
|
|
|
if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>) {
|
|
PackedK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK;
|
|
StrideK = MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K;
|
|
} else if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation<MLAS_GEMM_U8U8_KERNEL_AVX2>) {
|
|
PackedK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK;
|
|
StrideK = MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K;
|
|
} else {
|
|
#ifdef MLAS_NO_EXCEPTION
|
|
abort();
|
|
#else
|
|
throw std::runtime_error("packing unavailable");
|
|
#endif
|
|
}
|
|
#elif defined(MLAS_NEON_INTRINSICS)
|
|
PackedK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedK;
|
|
StrideK = MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K;
|
|
#else
|
|
#error Unknown architecture.
|
|
#endif
|
|
|
|
//
|
|
// Reserve and initialize storage for the column sum buffer to hold the sums
|
|
// of the elements along each of the columns.
|
|
//
|
|
|
|
const size_t AlignedN =
|
|
(N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1);
|
|
|
|
int32_t* PackedColumnSumBuffer = (int32_t*)PackedB;
|
|
std::fill_n(PackedColumnSumBuffer, AlignedN, 0);
|
|
PackedB = PackedColumnSumBuffer + AlignedN;
|
|
|
|
//
|
|
// Step through each slice of matrix B along the K dimension.
|
|
//
|
|
|
|
size_t CountK;
|
|
|
|
for (size_t k = 0; k < K; k += CountK) {
|
|
|
|
CountK = std::min(K - k, StrideK);
|
|
|
|
//
|
|
// Step through each slice of matrix B along the N dimension.
|
|
//
|
|
|
|
const size_t AlignedK = (CountK + PackedK - 1) & ~(PackedK - 1);
|
|
uint8_t* pb = (uint8_t*)PackedB;
|
|
size_t CountN;
|
|
|
|
for (size_t n = 0; n < N; n += CountN) {
|
|
|
|
constexpr size_t BatchedN = 128;
|
|
MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[BatchedN], 64);
|
|
|
|
CountN = std::min(N - n, BatchedN);
|
|
|
|
#if defined(MLAS_TARGET_AMD64)
|
|
if (GemmU8X8Operation == &MlasGemmU8X8PackedOperation<MLAS_GEMM_U8S8_KERNEL_AVX2>) {
|
|
MLAS_GEMM_U8S8_KERNEL_AVX2::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned);
|
|
} else {
|
|
MLAS_GEMM_U8U8_KERNEL_AVX2::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned);
|
|
}
|
|
#elif defined(MLAS_NEON_INTRINSICS)
|
|
MLAS_GEMM_U8X8_KERNEL_NEON::CopyPackB(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned);
|
|
#else
|
|
#error Unknown architecture.
|
|
#endif
|
|
|
|
//
|
|
// Accumulate this batch of the column sum buffer into the packed
|
|
// buffer accumulators.
|
|
//
|
|
|
|
for (size_t nn = 0; nn < CountN; nn++) {
|
|
PackedColumnSumBuffer[n + nn] += ColumnSumBuffer[nn];
|
|
}
|
|
|
|
pb += CountN * AlignedK;
|
|
}
|
|
|
|
PackedB = (uint8_t*)PackedB + AlignedN * AlignedK;
|
|
B += ldb * CountK;
|
|
}
|
|
}
|
|
|
|
#endif // MLAS_SUPPORTS_PACKED_GEMM_U8X8
|