MLAS: handle MlasGemm(M/N/K==0) cases (#6238)

This commit is contained in:
Tracy Sharpe 2020-12-30 23:25:10 -08:00 committed by GitHub
parent 4cc2ffef21
commit ecb2e119e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 312 additions and 294 deletions

View file

@ -32,21 +32,21 @@ Abstract:
//
struct MLAS_DGEMM_WORK_BLOCK {
int32_t ThreadCountM;
int32_t ThreadCountN;
CBLAS_TRANSPOSE TransA;
CBLAS_TRANSPOSE TransB;
size_t M;
size_t N;
size_t K;
const double* A;
size_t lda;
const double* B;
size_t ldb;
double* C;
size_t ldc;
double alpha;
double beta;
struct SEGMENT {
size_t M;
size_t N;
const double* A;
const double* B;
double* C;
} Segments[MLAS_MAXIMUM_THREAD_COUNT];
};
#ifdef MLAS_TARGET_AMD64
@ -86,7 +86,7 @@ Return Value:
{
MLAS_FLOAT64X2 BetaBroadcast = MlasBroadcastFloat64x2(beta);
do {
while (CountM-- > 0) {
double* c = C;
size_t n = CountN;
@ -106,9 +106,7 @@ Return Value:
}
C += ldc;
CountM--;
} while (CountM > 0);
}
}
void
@ -497,6 +495,82 @@ Return Value:
}
}
MLAS_FORCEINLINE
double*
MlasDgemmKernelLoop(
const double* A,
const double* B,
double* C,
size_t CountK,
size_t CountM,
size_t CountN,
size_t lda,
size_t ldc,
double alpha,
bool ZeroMode
)
/*++
Routine Description:
This routine steps through the rows of the input and output matrices calling
the kernel until all rows have been processed.
Arguments:
A - Supplies the address of matrix A.
B - Supplies the address of matrix B. The matrix data has been packed using
MlasDgemmCopyPackB or MlasDgemmTransposePackB.
C - Supplies the address of matrix C.
CountK - Supplies the number of columns from matrix A and the number of rows
from matrix B to iterate over.
CountM - Supplies the number of rows from matrix A and matrix C to iterate
over.
CountN - Supplies the number of columns from matrix B and matrix C to
iterate over.
lda - Supplies the first dimension of matrix A.
ldc - Supplies the first dimension of matrix C.
alpha - Supplies the scalar alpha multiplier (see DGEMM definition).
ZeroMode - Supplies true if the output matrix must be zero initialized,
else false if the output matrix is accumulated into.
Return Value:
Returns the next address of matrix C.
--*/
{
while (CountM > 0) {
size_t RowsHandled;
#if defined(MLAS_TARGET_AMD64_IX86)
RowsHandled = MlasPlatform.GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#else
if (ZeroMode) {
RowsHandled = MlasDgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
} else {
RowsHandled = MlasDgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
}
#endif
C += ldc * RowsHandled;
A += lda * RowsHandled;
CountM -= RowsHandled;
}
return C;
}
void
MlasDgemmOperation(
CBLAS_TRANSPOSE TransA,
@ -558,6 +632,16 @@ Return Value:
double PanelA[MLAS_DGEMM_TRANSA_ROWS * MLAS_DGEMM_STRIDEK];
MLAS_DECLSPEC_ALIGN(double PanelB[MLAS_DGEMM_STRIDEN * MLAS_DGEMM_STRIDEK], 8 * sizeof(double));
//
// Handle the special case of K equals zero. Apply the beta multiplier to
// the output matrix and exit.
//
if (K == 0) {
MlasDgemmMultiplyBeta(C, M, N, ldc, beta);
return;
}
//
// Compute the strides to step through slices of the input matrices.
//
@ -566,8 +650,8 @@ Return Value:
// the A panel needs to be used for transposing.
//
uint32_t StrideN = MLAS_DGEMM_STRIDEN;
uint32_t StrideK = MLAS_DGEMM_STRIDEK;
size_t StrideN = MLAS_DGEMM_STRIDEN;
size_t StrideK = MLAS_DGEMM_STRIDEK;
if (N >= K) {
@ -589,15 +673,10 @@ Return Value:
//
size_t CountN;
size_t CountK;
for (size_t n = 0; n < N; n += CountN) {
CountN = StrideN;
if (CountN > (N - n)) {
CountN = N - n;
}
CountN = std::min(N - n, StrideN);
//
// Multiply the output matrix by beta as needed.
@ -611,15 +690,12 @@ Return Value:
// Step through each slice of matrix B along the K dimension.
//
size_t CountK;
bool ZeroMode = (beta == 0.0f);
for (size_t k = 0; k < K; k += CountK) {
bool ZeroMode = (k == 0 && beta == 0.0f);
CountK = StrideK;
if (CountK > (K - k)) {
CountK = K - k;
}
CountK = std::min(K - k, StrideK);
//
// Copy or transpose a panel of matrix B to a local packed buffer.
@ -637,93 +713,45 @@ Return Value:
double* c = C + n;
size_t RowsRemaining = M;
size_t RowsHandled;
if (TransA == CblasNoTrans) {
const double* a = A + k;
//
// Step through the rows of matrix A.
//
do {
#if defined(MLAS_TARGET_AMD64_IX86)
RowsHandled = MlasPlatform.GemmDoubleKernel(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha, ZeroMode);
#else
if (ZeroMode) {
RowsHandled = MlasDgemmKernelZero(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
} else {
RowsHandled = MlasDgemmKernelAdd(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
}
#endif
c += ldc * RowsHandled;
a += lda * RowsHandled;
RowsRemaining -= RowsHandled;
} while (RowsRemaining > 0);
MlasDgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode);
} else {
const double* a = A + k * lda;
size_t RowsRemaining = M;
do {
while (RowsRemaining > 0) {
//
// Transpose elements from matrix A into a local buffer.
//
size_t RowsTransposed = RowsRemaining;
if (RowsTransposed > MLAS_DGEMM_TRANSA_ROWS) {
RowsTransposed = MLAS_DGEMM_TRANSA_ROWS;
}
RowsRemaining -= RowsTransposed;
size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_DGEMM_TRANSA_ROWS));
MlasDgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK);
RowsRemaining -= RowsTransposed;
a += RowsTransposed;
//
// Step through the rows of the local buffer.
//
const double* pa = PanelA;
do {
#if defined(MLAS_TARGET_AMD64_IX86)
RowsHandled = MlasPlatform.GemmDoubleKernel(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
#else
if (ZeroMode) {
RowsHandled = MlasDgemmKernelZero(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
} else {
RowsHandled = MlasDgemmKernelAdd(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
}
#endif
c += ldc * RowsHandled;
pa += CountK * RowsHandled;
RowsTransposed -= RowsHandled;
} while (RowsTransposed > 0);
} while (RowsRemaining > 0);
c = MlasDgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
}
}
ZeroMode = false;
}
}
}
void
MlasDgemmOperationThreaded(
MlasDgemmThreaded(
void* Context,
int32_t Index
int32_t ThreadId
)
/*++
@ -736,7 +764,7 @@ Arguments:
Context - Supplies the pointer to the context for the threaded operation.
Index - Supplies the current index of the threaded operation.
ThreadId - Supplies the current index of the threaded operation.
Return Value:
@ -744,19 +772,147 @@ Return Value:
--*/
{
MLAS_DGEMM_WORK_BLOCK* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
MLAS_DGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
const int32_t ThreadCountM = WorkBlock->ThreadCountM;
const int32_t ThreadCountN = WorkBlock->ThreadCountN;
MlasDgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M,
Segment->N, WorkBlock->K, WorkBlock->alpha, Segment->A, WorkBlock->lda,
Segment->B, WorkBlock->ldb, WorkBlock->beta, Segment->C,
WorkBlock->ldc);
const int32_t ThreadIdM = ThreadId / ThreadCountN;
const int32_t ThreadIdN = ThreadId % ThreadCountN;
//
// Partition the operation along the M dimension.
//
size_t M = WorkBlock->M;
size_t RangeStartM;
size_t RangeCountM;
MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM);
//
// Partition the operation along the N dimension.
//
size_t N = WorkBlock->N;
size_t RangeStartN;
size_t RangeCountN;
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN,
&RangeCountN);
RangeStartN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
RangeCountN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
RangeCountN = std::min(N - RangeStartN, RangeCountN);
//
// Dispatch the partitioned operation.
//
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
const size_t lda = WorkBlock->lda;
const size_t ldb = WorkBlock->ldb;
const size_t ldc = WorkBlock->ldc;
const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
}
inline
bool
MlasDgemmTryMultithread(
void
MlasDgemmSchedule(
MLAS_DGEMM_WORK_BLOCK* WorkBlock,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine schedules the double precision matrix/matrix multiply
operation (DGEMM) across one or more threads.
Arguments:
WorkBlock - Supplies the structure containing the GEMM parameters.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
const size_t M = WorkBlock->M;
const size_t N = WorkBlock->N;
const size_t K = WorkBlock->K;
//
// Compute the number of target threads given the complexity of the DGEMM
// operation. Small requests should run using the single threaded path.
//
const double Complexity = double(M) * double(N) * double(K);
int32_t TargetThreadCount;
if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}
int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
}
//
// Segment the operation across multiple threads.
//
// N.B. Currently, the operation is segmented as a 1D partition, which
// works okay for operations involving skinny matrices.
//
if (N > M) {
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
if (size_t(TargetThreadCount) > BlockedN) {
TargetThreadCount = int32_t(BlockedN);
}
WorkBlock->ThreadCountM = 1;
WorkBlock->ThreadCountN = TargetThreadCount;
} else {
if (size_t(TargetThreadCount) > M) {
TargetThreadCount = int32_t(M);
}
WorkBlock->ThreadCountM = TargetThreadCount;
WorkBlock->ThreadCountN = 1;
}
MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
}
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
@ -776,8 +932,8 @@ MlasDgemmTryMultithread(
Routine Description:
This routine attempts to launch a single precision matrix/matrix multiply
operation (DGEMM) across multiple threads.
This routine implements the double precision matrix/matrix multiply
operation (DGEMM).
Arguments:
@ -813,190 +969,37 @@ Arguments:
Return Value:
Returns true if the operation was completed across multiple threads, else
false if the operation should fall back to a single thread.
None.
--*/
{
MLAS_DGEMM_WORK_BLOCK WorkBlock;
int32_t TargetThreadCount;
//
// Compute the number of target threads given the complexity of the DGEMM
// operation. Small requests should run using the single threaded path.
// Capture the GEMM parameters to the work block.
//
double Complexity = double(M) * double(N) * double(K);
if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}
int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
}
if (TargetThreadCount == 1) {
return false;
}
//
// Initialize the common fields of the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK));
WorkBlock.TransA = TransA;
WorkBlock.TransB = TransB;
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = B;
WorkBlock.ldb = ldb;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
//
// Segment the operation across multiple threads.
// Schedule the operation across a set of worker threads.
//
int32_t Index = 0;
if (N > M) {
size_t StrideN = N / TargetThreadCount;
if ((StrideN * TargetThreadCount) != N) {
StrideN++;
}
StrideN =
(StrideN + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1);
size_t pldb = (TransB == CblasNoTrans) ? 1 : ldb;
for (size_t CountN, n = 0; n < N; n += CountN) {
CountN = StrideN;
if (CountN > (N - n)) {
CountN = N - n;
}
WorkBlock.Segments[Index].M = M;
WorkBlock.Segments[Index].N = CountN;
WorkBlock.Segments[Index].A = A;
WorkBlock.Segments[Index].B = B + n * pldb;
WorkBlock.Segments[Index].C = C + n;
Index++;
}
} else {
size_t StrideM = M / TargetThreadCount;
if ((StrideM * TargetThreadCount) != M) {
StrideM++;
}
size_t plda = (TransA == CblasNoTrans) ? lda : 1;
for (size_t CountM, m = 0; m < M; m += CountM) {
CountM = StrideM;
if (CountM > (M - m)) {
CountM = M - m;
}
WorkBlock.Segments[Index].M = CountM;
WorkBlock.Segments[Index].N = N;
WorkBlock.Segments[Index].A = A + m * plda;
WorkBlock.Segments[Index].B = B;
WorkBlock.Segments[Index].C = C + m * ldc;
Index++;
}
}
MlasExecuteThreaded(MlasDgemmOperationThreaded, &WorkBlock, Index, ThreadPool);
return true;
}
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
double alpha,
const double* A,
size_t lda,
const double* B,
size_t ldb,
double beta,
double* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the single precision matrix/matrix multiply
operation (SGEMM).
Arguments:
TransA - Supplies the transpose operation for matrix A.
TransB - Supplies the transpose operation for matrix B.
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
B - Supplies the address of matrix B.
ldb - Supplies the first dimension of matrix B.
beta - Supplies the scalar beta multiplier (see SGEMM definition).
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
//
// Try to run the operation across multiple threads or fall back to a
// single thread based on the GEMM parameters and system configuration.
//
if (!MlasDgemmTryMultithread(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, ThreadPool)) {
MlasDgemmOperation(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
MlasDgemmSchedule(&WorkBlock, ThreadPool);
}
#endif

View file

@ -41,13 +41,13 @@ struct MLAS_SGEMM_WORK_BLOCK {
size_t K;
const float* A;
size_t lda;
const float* B;
const void* B;
size_t ldb;
const void* PackedB;
float* C;
size_t ldc;
float alpha;
float beta;
bool BIsPacked;
};
void
@ -85,7 +85,7 @@ Return Value:
{
MLAS_FLOAT32X4 BetaBroadcast = MlasBroadcastFloat32x4(beta);
do {
while (CountM-- > 0) {
float* c = C;
size_t n = CountN;
@ -107,9 +107,7 @@ Return Value:
}
C += ldc;
CountM--;
} while (CountM > 0);
}
}
void
@ -807,7 +805,7 @@ Return Value:
--*/
{
do {
while (CountM > 0) {
size_t RowsHandled;
@ -826,8 +824,7 @@ Return Value:
C += ldc * RowsHandled;
A += lda * RowsHandled;
CountM -= RowsHandled;
} while (CountM > 0);
}
return C;
}
@ -893,6 +890,16 @@ Return Value:
float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_STRIDEK];
MLAS_DECLSPEC_ALIGN(float PanelB[MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK], 16 * sizeof(float));
//
// Handle the special case of K equals zero. Apply the beta multiplier to
// the output matrix and exit.
//
if (K == 0) {
MlasSgemmMultiplyBeta(C, M, N, ldc, beta);
return;
}
//
// Handle the special case of a small M. The data from matrix B is not
// referenced multiple times, so using a local packed buffer is a wasted
@ -1035,7 +1042,7 @@ Return Value:
const float* a = A + k * lda;
size_t RowsRemaining = M;
do {
while (RowsRemaining > 0) {
//
// Transpose elements from matrix A into a local buffer.
@ -1053,8 +1060,7 @@ Return Value:
//
c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
} while (RowsRemaining > 0);
}
}
ZeroMode = false;
@ -1171,7 +1177,7 @@ Return Value:
const float* a = A + k * lda;
size_t RowsRemaining = M;
do {
while (RowsRemaining > 0) {
//
// Transpose elements from matrix A into a local buffer.
@ -1189,8 +1195,7 @@ Return Value:
//
c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
} while (RowsRemaining > 0);
}
}
ZeroMode = false;
@ -1271,22 +1276,22 @@ Return Value:
const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
if (WorkBlock->B != nullptr) {
if (WorkBlock->BIsPacked) {
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B,
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
} else {
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
const size_t ldb = WorkBlock->ldb;
const float* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
} else {
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->PackedB,
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
}
}
@ -1535,11 +1540,12 @@ Return Value:
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.PackedB = PackedB;
WorkBlock.B = PackedB;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
WorkBlock.BIsPacked = true;
//
// Schedule the operation across a set of worker threads.

View file

@ -202,10 +202,10 @@ public:
};
template<typename T, bool Packed>
class MlasFgemmTestBase;
class FgemmPackedContext;
template<typename T>
class MlasFgemmTestBase<T, false> : public MlasTestBase
class FgemmPackedContext<T, false>
{
public:
void
@ -230,7 +230,7 @@ public:
};
template<typename T>
class MlasFgemmTestBase<T, true> : public MlasTestBase
class FgemmPackedContext<T, true>
{
public:
void
@ -261,7 +261,7 @@ private:
};
template<typename T, bool Packed>
class MlasFgemmTest : public MlasFgemmTestBase<T, Packed>
class MlasFgemmTest : public MlasTestBase
{
private:
void
@ -273,6 +273,14 @@ private:
float beta
)
{
//
// Skip the test if the B buffer cannot be packed.
//
if (Packed && (N == 0 || K == 0)) {
return;
}
const T* A = BufferA.GetBuffer(K * M);
const T* B = BufferB.GetBuffer(N * K);
T* C = BufferC.GetBuffer(N * M);
@ -305,7 +313,7 @@ private:
std::fill_n(C, M * N, -0.5f);
std::fill_n(CReference, M * N, -0.5f);
this->TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc);
for (size_t f = 0; f < M * N; f++) {
@ -430,6 +438,7 @@ private:
MatrixGuardBuffer<T> BufferB;
MatrixGuardBuffer<T> BufferC;
MatrixGuardBuffer<T> BufferCReference;
FgemmPackedContext<T, Packed> PackedContext;
public:
void
@ -437,7 +446,7 @@ public:
void
) override
{
for (size_t b = 1; b < 16; b++) {
for (size_t b = 0; b < 16; b++) {
Test(b, b, b, 1.0f, 0.0f);
}
for (size_t b = 16; b <= 256; b <<= 1) {
@ -504,9 +513,9 @@ public:
}
}
for (size_t M = 1; M < 160; M++) {
for (size_t N = 1; N < 160; N++) {
for (size_t K = 1; K < 160; K++) {
for (size_t M = 0; M < 160; M++) {
for (size_t N = 0; N < 160; N++) {
for (size_t K = 0; K < 160; K++) {
Test(M, N, K, 1.0f, 0.0f);
}
}
@ -515,7 +524,7 @@ public:
for (size_t M = 160; M < 320; M += 24) {
for (size_t N = 112; N < 320; N += 24) {
for (size_t K = 1; K < 16; K++) {
for (size_t K = 0; K < 16; K++) {
Test(M, N, K, 1.0f, 0.0f);
}
for (size_t K = 16; K < 160; K += 32) {