mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
MLAS: handle MlasGemm(M/N/K==0) cases (#6238)
This commit is contained in:
parent
4cc2ffef21
commit
ecb2e119e4
3 changed files with 312 additions and 294 deletions
|
|
@ -32,21 +32,21 @@ Abstract:
|
|||
//
|
||||
|
||||
struct MLAS_DGEMM_WORK_BLOCK {
|
||||
int32_t ThreadCountM;
|
||||
int32_t ThreadCountN;
|
||||
CBLAS_TRANSPOSE TransA;
|
||||
CBLAS_TRANSPOSE TransB;
|
||||
size_t M;
|
||||
size_t N;
|
||||
size_t K;
|
||||
const double* A;
|
||||
size_t lda;
|
||||
const double* B;
|
||||
size_t ldb;
|
||||
double* C;
|
||||
size_t ldc;
|
||||
double alpha;
|
||||
double beta;
|
||||
struct SEGMENT {
|
||||
size_t M;
|
||||
size_t N;
|
||||
const double* A;
|
||||
const double* B;
|
||||
double* C;
|
||||
} Segments[MLAS_MAXIMUM_THREAD_COUNT];
|
||||
};
|
||||
|
||||
#ifdef MLAS_TARGET_AMD64
|
||||
|
|
@ -86,7 +86,7 @@ Return Value:
|
|||
{
|
||||
MLAS_FLOAT64X2 BetaBroadcast = MlasBroadcastFloat64x2(beta);
|
||||
|
||||
do {
|
||||
while (CountM-- > 0) {
|
||||
|
||||
double* c = C;
|
||||
size_t n = CountN;
|
||||
|
|
@ -106,9 +106,7 @@ Return Value:
|
|||
}
|
||||
|
||||
C += ldc;
|
||||
CountM--;
|
||||
|
||||
} while (CountM > 0);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -497,6 +495,82 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
double*
|
||||
MlasDgemmKernelLoop(
|
||||
const double* A,
|
||||
const double* B,
|
||||
double* C,
|
||||
size_t CountK,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t lda,
|
||||
size_t ldc,
|
||||
double alpha,
|
||||
bool ZeroMode
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine steps through the rows of the input and output matrices calling
|
||||
the kernel until all rows have been processed.
|
||||
|
||||
Arguments:
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B. The matrix data has been packed using
|
||||
MlasDgemmCopyPackB or MlasDgemmTransposePackB.
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
CountK - Supplies the number of columns from matrix A and the number of rows
|
||||
from matrix B to iterate over.
|
||||
|
||||
CountM - Supplies the number of rows from matrix A and matrix C to iterate
|
||||
over.
|
||||
|
||||
CountN - Supplies the number of columns from matrix B and matrix C to
|
||||
iterate over.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
alpha - Supplies the scalar alpha multiplier (see DGEMM definition).
|
||||
|
||||
ZeroMode - Supplies true if the output matrix must be zero initialized,
|
||||
else false if the output matrix is accumulated into.
|
||||
|
||||
Return Value:
|
||||
|
||||
Returns the next address of matrix C.
|
||||
|
||||
--*/
|
||||
{
|
||||
while (CountM > 0) {
|
||||
|
||||
size_t RowsHandled;
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
RowsHandled = MlasPlatform.GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
|
||||
#else
|
||||
if (ZeroMode) {
|
||||
RowsHandled = MlasDgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
|
||||
} else {
|
||||
RowsHandled = MlasDgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
|
||||
}
|
||||
#endif
|
||||
|
||||
C += ldc * RowsHandled;
|
||||
A += lda * RowsHandled;
|
||||
CountM -= RowsHandled;
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
|
||||
void
|
||||
MlasDgemmOperation(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
|
|
@ -558,6 +632,16 @@ Return Value:
|
|||
double PanelA[MLAS_DGEMM_TRANSA_ROWS * MLAS_DGEMM_STRIDEK];
|
||||
MLAS_DECLSPEC_ALIGN(double PanelB[MLAS_DGEMM_STRIDEN * MLAS_DGEMM_STRIDEK], 8 * sizeof(double));
|
||||
|
||||
//
|
||||
// Handle the special case of K equals zero. Apply the beta multiplier to
|
||||
// the output matrix and exit.
|
||||
//
|
||||
|
||||
if (K == 0) {
|
||||
MlasDgemmMultiplyBeta(C, M, N, ldc, beta);
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Compute the strides to step through slices of the input matrices.
|
||||
//
|
||||
|
|
@ -566,8 +650,8 @@ Return Value:
|
|||
// the A panel needs to be used for transposing.
|
||||
//
|
||||
|
||||
uint32_t StrideN = MLAS_DGEMM_STRIDEN;
|
||||
uint32_t StrideK = MLAS_DGEMM_STRIDEK;
|
||||
size_t StrideN = MLAS_DGEMM_STRIDEN;
|
||||
size_t StrideK = MLAS_DGEMM_STRIDEK;
|
||||
|
||||
if (N >= K) {
|
||||
|
||||
|
|
@ -589,15 +673,10 @@ Return Value:
|
|||
//
|
||||
|
||||
size_t CountN;
|
||||
size_t CountK;
|
||||
|
||||
for (size_t n = 0; n < N; n += CountN) {
|
||||
|
||||
CountN = StrideN;
|
||||
|
||||
if (CountN > (N - n)) {
|
||||
CountN = N - n;
|
||||
}
|
||||
CountN = std::min(N - n, StrideN);
|
||||
|
||||
//
|
||||
// Multiply the output matrix by beta as needed.
|
||||
|
|
@ -611,15 +690,12 @@ Return Value:
|
|||
// Step through each slice of matrix B along the K dimension.
|
||||
//
|
||||
|
||||
size_t CountK;
|
||||
bool ZeroMode = (beta == 0.0f);
|
||||
|
||||
for (size_t k = 0; k < K; k += CountK) {
|
||||
|
||||
bool ZeroMode = (k == 0 && beta == 0.0f);
|
||||
|
||||
CountK = StrideK;
|
||||
|
||||
if (CountK > (K - k)) {
|
||||
CountK = K - k;
|
||||
}
|
||||
CountK = std::min(K - k, StrideK);
|
||||
|
||||
//
|
||||
// Copy or transpose a panel of matrix B to a local packed buffer.
|
||||
|
|
@ -637,93 +713,45 @@ Return Value:
|
|||
|
||||
double* c = C + n;
|
||||
|
||||
size_t RowsRemaining = M;
|
||||
size_t RowsHandled;
|
||||
|
||||
if (TransA == CblasNoTrans) {
|
||||
|
||||
const double* a = A + k;
|
||||
|
||||
//
|
||||
// Step through the rows of matrix A.
|
||||
//
|
||||
|
||||
do {
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
RowsHandled = MlasPlatform.GemmDoubleKernel(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha, ZeroMode);
|
||||
#else
|
||||
if (ZeroMode) {
|
||||
RowsHandled = MlasDgemmKernelZero(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
|
||||
} else {
|
||||
RowsHandled = MlasDgemmKernelAdd(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha);
|
||||
}
|
||||
#endif
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
a += lda * RowsHandled;
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
MlasDgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode);
|
||||
|
||||
} else {
|
||||
|
||||
const double* a = A + k * lda;
|
||||
size_t RowsRemaining = M;
|
||||
|
||||
do {
|
||||
while (RowsRemaining > 0) {
|
||||
|
||||
//
|
||||
// Transpose elements from matrix A into a local buffer.
|
||||
//
|
||||
|
||||
size_t RowsTransposed = RowsRemaining;
|
||||
|
||||
if (RowsTransposed > MLAS_DGEMM_TRANSA_ROWS) {
|
||||
RowsTransposed = MLAS_DGEMM_TRANSA_ROWS;
|
||||
}
|
||||
|
||||
RowsRemaining -= RowsTransposed;
|
||||
size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_DGEMM_TRANSA_ROWS));
|
||||
|
||||
MlasDgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK);
|
||||
|
||||
RowsRemaining -= RowsTransposed;
|
||||
a += RowsTransposed;
|
||||
|
||||
//
|
||||
// Step through the rows of the local buffer.
|
||||
//
|
||||
|
||||
const double* pa = PanelA;
|
||||
|
||||
do {
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64_IX86)
|
||||
RowsHandled = MlasPlatform.GemmDoubleKernel(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
#else
|
||||
if (ZeroMode) {
|
||||
RowsHandled = MlasDgemmKernelZero(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
|
||||
} else {
|
||||
RowsHandled = MlasDgemmKernelAdd(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha);
|
||||
}
|
||||
#endif
|
||||
|
||||
c += ldc * RowsHandled;
|
||||
pa += CountK * RowsHandled;
|
||||
|
||||
RowsTransposed -= RowsHandled;
|
||||
|
||||
} while (RowsTransposed > 0);
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
c = MlasDgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
}
|
||||
}
|
||||
|
||||
ZeroMode = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MlasDgemmOperationThreaded(
|
||||
MlasDgemmThreaded(
|
||||
void* Context,
|
||||
int32_t Index
|
||||
int32_t ThreadId
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -736,7 +764,7 @@ Arguments:
|
|||
|
||||
Context - Supplies the pointer to the context for the threaded operation.
|
||||
|
||||
Index - Supplies the current index of the threaded operation.
|
||||
ThreadId - Supplies the current index of the threaded operation.
|
||||
|
||||
Return Value:
|
||||
|
||||
|
|
@ -744,19 +772,147 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
MLAS_DGEMM_WORK_BLOCK* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
|
||||
const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
|
||||
|
||||
MLAS_DGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];
|
||||
const int32_t ThreadCountM = WorkBlock->ThreadCountM;
|
||||
const int32_t ThreadCountN = WorkBlock->ThreadCountN;
|
||||
|
||||
MlasDgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M,
|
||||
Segment->N, WorkBlock->K, WorkBlock->alpha, Segment->A, WorkBlock->lda,
|
||||
Segment->B, WorkBlock->ldb, WorkBlock->beta, Segment->C,
|
||||
WorkBlock->ldc);
|
||||
const int32_t ThreadIdM = ThreadId / ThreadCountN;
|
||||
const int32_t ThreadIdN = ThreadId % ThreadCountN;
|
||||
|
||||
//
|
||||
// Partition the operation along the M dimension.
|
||||
//
|
||||
|
||||
size_t M = WorkBlock->M;
|
||||
size_t RangeStartM;
|
||||
size_t RangeCountM;
|
||||
|
||||
MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM);
|
||||
|
||||
//
|
||||
// Partition the operation along the N dimension.
|
||||
//
|
||||
|
||||
size_t N = WorkBlock->N;
|
||||
size_t RangeStartN;
|
||||
size_t RangeCountN;
|
||||
|
||||
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
||||
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
|
||||
|
||||
MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN,
|
||||
&RangeCountN);
|
||||
|
||||
RangeStartN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
|
||||
RangeCountN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
|
||||
|
||||
RangeCountN = std::min(N - RangeStartN, RangeCountN);
|
||||
|
||||
//
|
||||
// Dispatch the partitioned operation.
|
||||
//
|
||||
|
||||
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
|
||||
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
|
||||
|
||||
const size_t lda = WorkBlock->lda;
|
||||
const size_t ldb = WorkBlock->ldb;
|
||||
const size_t ldc = WorkBlock->ldc;
|
||||
|
||||
const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
|
||||
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
|
||||
}
|
||||
|
||||
inline
|
||||
bool
|
||||
MlasDgemmTryMultithread(
|
||||
void
|
||||
MlasDgemmSchedule(
|
||||
MLAS_DGEMM_WORK_BLOCK* WorkBlock,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine schedules the double precision matrix/matrix multiply
|
||||
operation (DGEMM) across one or more threads.
|
||||
|
||||
Arguments:
|
||||
|
||||
WorkBlock - Supplies the structure containing the GEMM parameters.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t M = WorkBlock->M;
|
||||
const size_t N = WorkBlock->N;
|
||||
const size_t K = WorkBlock->K;
|
||||
|
||||
//
|
||||
// Compute the number of target threads given the complexity of the DGEMM
|
||||
// operation. Small requests should run using the single threaded path.
|
||||
//
|
||||
|
||||
const double Complexity = double(M) * double(N) * double(K);
|
||||
|
||||
int32_t TargetThreadCount;
|
||||
|
||||
if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
|
||||
TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
} else {
|
||||
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
}
|
||||
|
||||
int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
|
||||
|
||||
if (TargetThreadCount >= MaximumThreadCount) {
|
||||
TargetThreadCount = MaximumThreadCount;
|
||||
}
|
||||
|
||||
//
|
||||
// Segment the operation across multiple threads.
|
||||
//
|
||||
// N.B. Currently, the operation is segmented as a 1D partition, which
|
||||
// works okay for operations involving skinny matrices.
|
||||
//
|
||||
|
||||
if (N > M) {
|
||||
|
||||
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
||||
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
|
||||
|
||||
if (size_t(TargetThreadCount) > BlockedN) {
|
||||
TargetThreadCount = int32_t(BlockedN);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = 1;
|
||||
WorkBlock->ThreadCountN = TargetThreadCount;
|
||||
|
||||
} else {
|
||||
|
||||
if (size_t(TargetThreadCount) > M) {
|
||||
TargetThreadCount = int32_t(M);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = TargetThreadCount;
|
||||
WorkBlock->ThreadCountN = 1;
|
||||
}
|
||||
|
||||
MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
|
|
@ -776,8 +932,8 @@ MlasDgemmTryMultithread(
|
|||
|
||||
Routine Description:
|
||||
|
||||
This routine attempts to launch a single precision matrix/matrix multiply
|
||||
operation (DGEMM) across multiple threads.
|
||||
This routine implements the double precision matrix/matrix multiply
|
||||
operation (DGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
|
|
@ -813,190 +969,37 @@ Arguments:
|
|||
|
||||
Return Value:
|
||||
|
||||
Returns true if the operation was completed across multiple threads, else
|
||||
false if the operation should fall back to a single thread.
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_DGEMM_WORK_BLOCK WorkBlock;
|
||||
int32_t TargetThreadCount;
|
||||
|
||||
//
|
||||
// Compute the number of target threads given the complexity of the DGEMM
|
||||
// operation. Small requests should run using the single threaded path.
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
double Complexity = double(M) * double(N) * double(K);
|
||||
|
||||
if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
|
||||
TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
|
||||
} else {
|
||||
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
|
||||
}
|
||||
|
||||
int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
|
||||
|
||||
if (TargetThreadCount >= MaximumThreadCount) {
|
||||
TargetThreadCount = MaximumThreadCount;
|
||||
}
|
||||
|
||||
if (TargetThreadCount == 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the common fields of the work block.
|
||||
//
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK));
|
||||
|
||||
WorkBlock.TransA = TransA;
|
||||
WorkBlock.TransB = TransB;
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = B;
|
||||
WorkBlock.ldb = ldb;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
|
||||
//
|
||||
// Segment the operation across multiple threads.
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
int32_t Index = 0;
|
||||
|
||||
if (N > M) {
|
||||
|
||||
size_t StrideN = N / TargetThreadCount;
|
||||
|
||||
if ((StrideN * TargetThreadCount) != N) {
|
||||
StrideN++;
|
||||
}
|
||||
|
||||
StrideN =
|
||||
(StrideN + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1);
|
||||
|
||||
size_t pldb = (TransB == CblasNoTrans) ? 1 : ldb;
|
||||
|
||||
for (size_t CountN, n = 0; n < N; n += CountN) {
|
||||
|
||||
CountN = StrideN;
|
||||
|
||||
if (CountN > (N - n)) {
|
||||
CountN = N - n;
|
||||
}
|
||||
|
||||
WorkBlock.Segments[Index].M = M;
|
||||
WorkBlock.Segments[Index].N = CountN;
|
||||
WorkBlock.Segments[Index].A = A;
|
||||
WorkBlock.Segments[Index].B = B + n * pldb;
|
||||
WorkBlock.Segments[Index].C = C + n;
|
||||
|
||||
Index++;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
size_t StrideM = M / TargetThreadCount;
|
||||
|
||||
if ((StrideM * TargetThreadCount) != M) {
|
||||
StrideM++;
|
||||
}
|
||||
|
||||
size_t plda = (TransA == CblasNoTrans) ? lda : 1;
|
||||
|
||||
for (size_t CountM, m = 0; m < M; m += CountM) {
|
||||
|
||||
CountM = StrideM;
|
||||
|
||||
if (CountM > (M - m)) {
|
||||
CountM = M - m;
|
||||
}
|
||||
|
||||
WorkBlock.Segments[Index].M = CountM;
|
||||
WorkBlock.Segments[Index].N = N;
|
||||
WorkBlock.Segments[Index].A = A + m * plda;
|
||||
WorkBlock.Segments[Index].B = B;
|
||||
WorkBlock.Segments[Index].C = C + m * ldc;
|
||||
|
||||
Index++;
|
||||
}
|
||||
}
|
||||
|
||||
MlasExecuteThreaded(MlasDgemmOperationThreaded, &WorkBlock, Index, ThreadPool);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
double alpha,
|
||||
const double* A,
|
||||
size_t lda,
|
||||
const double* B,
|
||||
size_t ldb,
|
||||
double beta,
|
||||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the single precision matrix/matrix multiply
|
||||
operation (SGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
TransA - Supplies the transpose operation for matrix A.
|
||||
|
||||
TransB - Supplies the transpose operation for matrix B.
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B.
|
||||
|
||||
ldb - Supplies the first dimension of matrix B.
|
||||
|
||||
beta - Supplies the scalar beta multiplier (see SGEMM definition).
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
//
|
||||
// Try to run the operation across multiple threads or fall back to a
|
||||
// single thread based on the GEMM parameters and system configuration.
|
||||
//
|
||||
|
||||
if (!MlasDgemmTryMultithread(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, ThreadPool)) {
|
||||
MlasDgemmOperation(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
}
|
||||
MlasDgemmSchedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -41,13 +41,13 @@ struct MLAS_SGEMM_WORK_BLOCK {
|
|||
size_t K;
|
||||
const float* A;
|
||||
size_t lda;
|
||||
const float* B;
|
||||
const void* B;
|
||||
size_t ldb;
|
||||
const void* PackedB;
|
||||
float* C;
|
||||
size_t ldc;
|
||||
float alpha;
|
||||
float beta;
|
||||
bool BIsPacked;
|
||||
};
|
||||
|
||||
void
|
||||
|
|
@ -85,7 +85,7 @@ Return Value:
|
|||
{
|
||||
MLAS_FLOAT32X4 BetaBroadcast = MlasBroadcastFloat32x4(beta);
|
||||
|
||||
do {
|
||||
while (CountM-- > 0) {
|
||||
|
||||
float* c = C;
|
||||
size_t n = CountN;
|
||||
|
|
@ -107,9 +107,7 @@ Return Value:
|
|||
}
|
||||
|
||||
C += ldc;
|
||||
CountM--;
|
||||
|
||||
} while (CountM > 0);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -807,7 +805,7 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
do {
|
||||
while (CountM > 0) {
|
||||
|
||||
size_t RowsHandled;
|
||||
|
||||
|
|
@ -826,8 +824,7 @@ Return Value:
|
|||
C += ldc * RowsHandled;
|
||||
A += lda * RowsHandled;
|
||||
CountM -= RowsHandled;
|
||||
|
||||
} while (CountM > 0);
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
|
|
@ -893,6 +890,16 @@ Return Value:
|
|||
float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_STRIDEK];
|
||||
MLAS_DECLSPEC_ALIGN(float PanelB[MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK], 16 * sizeof(float));
|
||||
|
||||
//
|
||||
// Handle the special case of K equals zero. Apply the beta multiplier to
|
||||
// the output matrix and exit.
|
||||
//
|
||||
|
||||
if (K == 0) {
|
||||
MlasSgemmMultiplyBeta(C, M, N, ldc, beta);
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Handle the special case of a small M. The data from matrix B is not
|
||||
// referenced multiple times, so using a local packed buffer is a wasted
|
||||
|
|
@ -1035,7 +1042,7 @@ Return Value:
|
|||
const float* a = A + k * lda;
|
||||
size_t RowsRemaining = M;
|
||||
|
||||
do {
|
||||
while (RowsRemaining > 0) {
|
||||
|
||||
//
|
||||
// Transpose elements from matrix A into a local buffer.
|
||||
|
|
@ -1053,8 +1060,7 @@ Return Value:
|
|||
//
|
||||
|
||||
c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
}
|
||||
}
|
||||
|
||||
ZeroMode = false;
|
||||
|
|
@ -1171,7 +1177,7 @@ Return Value:
|
|||
const float* a = A + k * lda;
|
||||
size_t RowsRemaining = M;
|
||||
|
||||
do {
|
||||
while (RowsRemaining > 0) {
|
||||
|
||||
//
|
||||
// Transpose elements from matrix A into a local buffer.
|
||||
|
|
@ -1189,8 +1195,7 @@ Return Value:
|
|||
//
|
||||
|
||||
c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);
|
||||
|
||||
} while (RowsRemaining > 0);
|
||||
}
|
||||
}
|
||||
|
||||
ZeroMode = false;
|
||||
|
|
@ -1271,22 +1276,22 @@ Return Value:
|
|||
const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
if (WorkBlock->B != nullptr) {
|
||||
if (WorkBlock->BIsPacked) {
|
||||
|
||||
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
|
||||
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B,
|
||||
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
|
||||
|
||||
} else {
|
||||
|
||||
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
|
||||
|
||||
const size_t ldb = WorkBlock->ldb;
|
||||
|
||||
const float* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
|
||||
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
|
||||
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
|
||||
|
||||
} else {
|
||||
|
||||
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
|
||||
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->PackedB,
|
||||
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1535,11 +1540,12 @@ Return Value:
|
|||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.PackedB = PackedB;
|
||||
WorkBlock.B = PackedB;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
WorkBlock.BIsPacked = true;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
|
|
|
|||
|
|
@ -202,10 +202,10 @@ public:
|
|||
};
|
||||
|
||||
template<typename T, bool Packed>
|
||||
class MlasFgemmTestBase;
|
||||
class FgemmPackedContext;
|
||||
|
||||
template<typename T>
|
||||
class MlasFgemmTestBase<T, false> : public MlasTestBase
|
||||
class FgemmPackedContext<T, false>
|
||||
{
|
||||
public:
|
||||
void
|
||||
|
|
@ -230,7 +230,7 @@ public:
|
|||
};
|
||||
|
||||
template<typename T>
|
||||
class MlasFgemmTestBase<T, true> : public MlasTestBase
|
||||
class FgemmPackedContext<T, true>
|
||||
{
|
||||
public:
|
||||
void
|
||||
|
|
@ -261,7 +261,7 @@ private:
|
|||
};
|
||||
|
||||
template<typename T, bool Packed>
|
||||
class MlasFgemmTest : public MlasFgemmTestBase<T, Packed>
|
||||
class MlasFgemmTest : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
void
|
||||
|
|
@ -273,6 +273,14 @@ private:
|
|||
float beta
|
||||
)
|
||||
{
|
||||
//
|
||||
// Skip the test if the B buffer cannot be packed.
|
||||
//
|
||||
|
||||
if (Packed && (N == 0 || K == 0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T* A = BufferA.GetBuffer(K * M);
|
||||
const T* B = BufferB.GetBuffer(N * K);
|
||||
T* C = BufferC.GetBuffer(N * M);
|
||||
|
|
@ -305,7 +313,7 @@ private:
|
|||
std::fill_n(C, M * N, -0.5f);
|
||||
std::fill_n(CReference, M * N, -0.5f);
|
||||
|
||||
this->TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc);
|
||||
|
||||
for (size_t f = 0; f < M * N; f++) {
|
||||
|
|
@ -430,6 +438,7 @@ private:
|
|||
MatrixGuardBuffer<T> BufferB;
|
||||
MatrixGuardBuffer<T> BufferC;
|
||||
MatrixGuardBuffer<T> BufferCReference;
|
||||
FgemmPackedContext<T, Packed> PackedContext;
|
||||
|
||||
public:
|
||||
void
|
||||
|
|
@ -437,7 +446,7 @@ public:
|
|||
void
|
||||
) override
|
||||
{
|
||||
for (size_t b = 1; b < 16; b++) {
|
||||
for (size_t b = 0; b < 16; b++) {
|
||||
Test(b, b, b, 1.0f, 0.0f);
|
||||
}
|
||||
for (size_t b = 16; b <= 256; b <<= 1) {
|
||||
|
|
@ -504,9 +513,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
for (size_t M = 1; M < 160; M++) {
|
||||
for (size_t N = 1; N < 160; N++) {
|
||||
for (size_t K = 1; K < 160; K++) {
|
||||
for (size_t M = 0; M < 160; M++) {
|
||||
for (size_t N = 0; N < 160; N++) {
|
||||
for (size_t K = 0; K < 160; K++) {
|
||||
Test(M, N, K, 1.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
|
|
@ -515,7 +524,7 @@ public:
|
|||
|
||||
for (size_t M = 160; M < 320; M += 24) {
|
||||
for (size_t N = 112; N < 320; N += 24) {
|
||||
for (size_t K = 1; K < 16; K++) {
|
||||
for (size_t K = 0; K < 16; K++) {
|
||||
Test(M, N, K, 1.0f, 0.0f);
|
||||
}
|
||||
for (size_t K = 16; K < 160; K += 32) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue