Add batch interface to floating point GEMM (#7323)

Currently, for high-dimensional matmul, we call multiple GEMMs sequentially. In this change we execute these GEMMs in parallel, removing the barriers between adjacent GEMM operations.

Performance was tested with the BERT and T5 models. The BERT model shows no noticeable performance difference, as the heavy lifting is done by the attention operator, which is not changed in this PR. In the T5 model, we see no regression with a low number of parallel threads (x4), and the performance improvement is more pronounced with a high number of threads (8-16). T5 shows a 10% speedup with 16 threads. With profiling, we can see that the most expensive MatMul operators in T5 achieve around a 20% speedup with 16 threads.

Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
Chen Fu 2021-04-23 17:34:22 -07:00 committed by GitHub
parent 7a3c1787af
commit f4f2cc1a00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 624 additions and 623 deletions

View file

@ -309,6 +309,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
T* qkv_dest = QKV[qkv_index];
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
// TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process?
// broadcast 3NH -> (3.B.N.S.H)
const T* broadcast_data_src = bias_data + weights_offset;
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;

View file

@ -147,10 +147,102 @@ MlasActivation(
//
// Matrix/matrix multiply routines.
// C := alpha * op(A) * op(B) + beta * C
// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X))
//
/**
* @brief Supply matrices data information to single precision gemm functions
*
* One instance describes the operands of a single multiplication
* C := alpha * op(A) * op(B) + beta * C. The batched entry point takes an
* array of these, one per multiplication. A default-constructed instance has
* null matrix pointers, zero leading dimensions, alpha = 1, beta = 0 and an
* unpacked B.
*/
struct MLAS_SGEMM_DATA_PARAMS {
const float* A = nullptr; /**< Supplies the address of matrix A */
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
const float* B = nullptr; /**< Supplies the address of matrix B, or of the packed B buffer when BIsPacked is true */
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
float* C = nullptr; /**< Supplies the address of matrix C */
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */
float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */
bool BIsPacked = false; /**< Whether B is pre-packed */
};
/**
* @brief Batched single precision matrix/matrix multiply operation (SGEMM)
*
* Every multiplication in the batch uses the same TransA, TransB, M, N and K;
* only the per-entry fields of Data (matrix addresses, leading dimensions,
* alpha, beta, packing flag) vary between entries.
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param Data Supplies an array of matrices data parameters, one entry
per multiplication in the batch.
* @param BatchSize Supplies number of multiplications in this batch
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_SGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Single precision matrix/matrix multiply operation (SGEMM)
 *
 * Convenience overload for one multiplication whose operands are already
 * described by a data parameter structure. It forwards to the batched
 * routine with a batch of exactly one entry.
 *
 * @param TransA      Supplies the transpose operation for matrix A.
 * @param TransB      Supplies the transpose operation for matrix B.
 * @param M           Supplies the number of rows of matrix A and matrix C.
 * @param N           Supplies the number of columns of matrix B and matrix C.
 * @param K           Supplies the number of columns of matrix A and the number
 *                    of rows of matrix B.
 * @param Data        Supplies the matrices data parameters.
 * @param ThreadPool  Supplies the thread pool object to use, else nullptr if
 *                    the base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_SGEMM_DATA_PARAMS& Data,
    MLAS_THREADPOOL* ThreadPool
    )
{
    // A lone GEMM is dispatched as a batch that holds a single entry.
    const MLAS_SGEMM_DATA_PARAMS* SingleEntry = &Data;
    const size_t SingleEntryCount = 1;

    MlasGemmBatch(TransA, TransB, M, N, K, SingleEntry, SingleEntryCount, ThreadPool);
}
/**
* @brief Single precision matrix/matrix multiply operation (SGEMM)
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
inline
void
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
@ -166,10 +258,41 @@ MlasGemm(
float* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
)
{
MLAS_SGEMM_DATA_PARAMS Data;
Data.alpha = alpha;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.beta = beta;
Data.C = C;
Data.ldc = ldc;
MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool);
}
/**
* @brief the single precision matrix/matrix multiply operation (SGEMM) with pre-packed B
*
* @param TransA - Supplies the transpose operation for matrix A.
* @param M - Supplies the number of rows of matrix A and matrix C.
* @param N - Supplies the number of columns of matrix B and matrix C.
* @param K - Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
* @param A - Supplies the address of matrix A.
* @param lda - Supplies the first dimension of matrix A.
* @param PackedB - Supplies the address of packed matrix B.
* @param beta - Supplies the scalar beta multiplier (see SGEMM definition).
* @param C - Supplies the address of matrix C.
* @param ldc - Supplies the first dimension of matrix C.
* @param ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
inline
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
size_t M,
@ -183,10 +306,117 @@ MlasGemm(
float* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
)
{
MLAS_SGEMM_DATA_PARAMS DataParams;
DataParams.A = A;
DataParams.lda = lda;
DataParams.B = static_cast<const float*>(PackedB);
DataParams.ldb = 0;
DataParams.C = C;
DataParams.ldc = ldc;
DataParams.alpha = alpha;
DataParams.beta = beta;
DataParams.BIsPacked = true;
MlasGemmBatch(TransA,
CblasTrans, // does not matter when B is packed
M, N, K, &DataParams, 1, ThreadPool);
}
/**
* @brief Supply matrices data information to double precision gemm functions
*
* One instance describes the operands of a single multiplication
* C := alpha * op(A) * op(B) + beta * C. The batched entry point takes an
* array of these, one per multiplication. A default-constructed instance has
* null matrix pointers, zero leading dimensions, alpha = 1 and beta = 0.
*/
struct MLAS_DGEMM_DATA_PARAMS {
const double* A = nullptr; /**< Supplies the address of matrix A */
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
const double* B = nullptr; /**< Supplies the address of matrix B */
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
double* C = nullptr; /**< Supplies the address of matrix C */
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see DGEMM definition) */
double beta = 0.0; /**< Supplies the scalar beta multiplier (see DGEMM definition) */
};
/**
* @brief Batched double precision matrix/matrix multiply operation (DGEMM)
*
* Every multiplication in the batch uses the same TransA, TransB, M, N and K;
* only the per-entry fields of Data (matrix addresses, leading dimensions,
* alpha, beta) vary between entries.
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param Data Supplies an array of matrices data parameters, one entry
per multiplication in the batch.
* @param BatchSize Supplies number of multiplications in this batch
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_DGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
/**
 * @brief Double precision matrix/matrix multiply operation (DGEMM)
 *
 * Convenience overload for one multiplication whose operands are already
 * described by a data parameter structure. It forwards to the batched
 * routine with a batch of exactly one entry.
 *
 * @param TransA      Supplies the transpose operation for matrix A.
 * @param TransB      Supplies the transpose operation for matrix B.
 * @param M           Supplies the number of rows of matrix A and matrix C.
 * @param N           Supplies the number of columns of matrix B and matrix C.
 * @param K           Supplies the number of columns of matrix A and the number
 *                    of rows of matrix B.
 * @param Data        Supplies the matrices data parameters.
 * @param ThreadPool  Supplies the thread pool object to use, else nullptr if
 *                    the base library threading support should be used.
 */
inline
void
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_DGEMM_DATA_PARAMS& Data,
    MLAS_THREADPOOL* ThreadPool
    )
{
    // A lone GEMM is dispatched as a batch that holds a single entry.
    const MLAS_DGEMM_DATA_PARAMS* SingleEntry = &Data;
    const size_t SingleEntryCount = 1;

    MlasGemmBatch(TransA, TransB, M, N, K, SingleEntry, SingleEntryCount, ThreadPool);
}
/**
* @brief Double precision matrix/matrix multiply operation (DGEMM)
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see DGEMM definition)
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param beta Supplies the scalar beta multiplier (see DGEMM definition)
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
inline
void
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
@ -202,7 +432,19 @@ MlasGemm(
double* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);
)
{
MLAS_DGEMM_DATA_PARAMS Data;
Data.alpha = alpha;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.beta = beta;
Data.C = C;
Data.ldc = ldc;
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
}
enum class MLAS_QUANTIZATION_GRANULARITY {
PerMatrix,

View file

@ -26,29 +26,6 @@ Abstract:
#define MLAS_DGEMM_TRANSA_ROWS 12
//
// Define the parameters to execute segments of a DGEMM operation on worker
// threads.
//
struct MLAS_DGEMM_WORK_BLOCK {
ptrdiff_t ThreadCountM;
ptrdiff_t ThreadCountN;
CBLAS_TRANSPOSE TransA;
CBLAS_TRANSPOSE TransB;
size_t M;
size_t N;
size_t K;
const double* A;
size_t lda;
const double* B;
size_t ldb;
double* C;
size_t ldc;
double alpha;
double beta;
};
#ifdef MLAS_TARGET_AMD64
void
@ -750,8 +727,15 @@ Return Value:
void
MlasDgemmThreaded(
void* Context,
ptrdiff_t ThreadId
const ptrdiff_t ThreadCountM,
const ptrdiff_t ThreadCountN,
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const size_t M,
const size_t N,
const size_t K,
const MLAS_DGEMM_DATA_PARAMS* Data,
const ptrdiff_t ThreadId
)
/*++
@ -772,10 +756,6 @@ Return Value:
--*/
{
const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM;
const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN;
const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN;
const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN;
@ -784,7 +764,6 @@ Return Value:
// Partition the operation along the M dimension.
//
size_t M = WorkBlock->M;
size_t RangeStartM;
size_t RangeCountM;
@ -794,7 +773,6 @@ Return Value:
// Partition the operation along the N dimension.
//
size_t N = WorkBlock->N;
size_t RangeStartN;
size_t RangeCountN;
@ -813,50 +791,32 @@ Return Value:
// Dispatch the partitioned operation.
//
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
const size_t lda = Data->lda;
const size_t ldb = Data->ldb;
const size_t ldc = Data->ldc;
const size_t lda = WorkBlock->lda;
const size_t ldb = WorkBlock->ldb;
const size_t ldc = WorkBlock->ldc;
const double* A = Data->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
const double* B = Data->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
double* C = Data->C + RangeStartM * ldc + RangeStartN;
const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K,
Data->alpha, A, lda, B, ldb, Data->beta, C, ldc);
}
void
MlasDgemmSchedule(
MLAS_DGEMM_WORK_BLOCK* WorkBlock,
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_DGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine schedules the double precision matrix/matrix multiply
operation (DGEMM) across one or more threads.
Arguments:
WorkBlock - Supplies the structure containing the GEMM parameters.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
const size_t M = WorkBlock->M;
const size_t N = WorkBlock->N;
const size_t K = WorkBlock->K;
//
// Compute the number of target threads given the complexity of the DGEMM
// operation. Small requests should run using the single threaded path.
@ -885,121 +845,40 @@ Return Value:
// works okay for operations involving skinny matrices.
//
ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize;
ptrdiff_t ThreadCountM;
ptrdiff_t ThreadCountN;
if (N > M) {
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
if (size_t(TargetThreadCount) > BlockedN) {
TargetThreadCount = ptrdiff_t(BlockedN);
if (size_t(ThreadsPerGemm) > BlockedN) {
ThreadsPerGemm = ptrdiff_t(BlockedN);
}
WorkBlock->ThreadCountM = 1;
WorkBlock->ThreadCountN = TargetThreadCount;
ThreadCountM = 1;
ThreadCountN = ThreadsPerGemm;
} else {
if (size_t(TargetThreadCount) > M) {
TargetThreadCount = ptrdiff_t(M);
if (size_t(ThreadsPerGemm) > M) {
ThreadsPerGemm = ptrdiff_t(M);
}
WorkBlock->ThreadCountM = TargetThreadCount;
WorkBlock->ThreadCountN = 1;
ThreadCountM = ThreadsPerGemm;
ThreadCountN = 1;
}
MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
}
const ptrdiff_t TotalThreads = ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize);
MlasTrySimpleParallel(ThreadPool, TotalThreads, [=](ptrdiff_t tid) {
const ptrdiff_t GemmIdx = tid / ThreadsPerGemm;
const ptrdiff_t ThreadIdx = tid % ThreadsPerGemm;
MlasDgemmThreaded(ThreadCountM, ThreadCountN, TransA, TransB,
M, N, K, &(Data[GemmIdx]), ThreadIdx);
});
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
double alpha,
const double* A,
size_t lda,
const double* B,
size_t ldb,
double beta,
double* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the double precision matrix/matrix multiply
operation (DGEMM).
Arguments:
TransA - Supplies the transpose operation for matrix A.
TransB - Supplies the transpose operation for matrix B.
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
alpha - Supplies the scalar alpha multiplier (see DGEMM definition).
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
B - Supplies the address of matrix B.
ldb - Supplies the first dimension of matrix B.
beta - Supplies the scalar beta multiplier (see DGEMM definition).
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
MLAS_DGEMM_WORK_BLOCK WorkBlock;
//
// Capture the GEMM parameters to the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK));
WorkBlock.TransA = TransA;
WorkBlock.TransB = TransB;
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = B;
WorkBlock.ldb = ldb;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
//
// Schedule the operation across a set of worker threads.
//
MlasDgemmSchedule(&WorkBlock, ThreadPool);
}
#endif

View file

@ -31,25 +31,6 @@ Abstract:
// threads.
//
struct MLAS_SGEMM_WORK_BLOCK {
ptrdiff_t ThreadCountM;
ptrdiff_t ThreadCountN;
CBLAS_TRANSPOSE TransA;
CBLAS_TRANSPOSE TransB;
size_t M;
size_t N;
size_t K;
const float* A;
size_t lda;
const void* B;
size_t ldb;
float* C;
size_t ldc;
float alpha;
float beta;
bool BIsPacked;
};
void
MlasSgemmMultiplyBeta(
float* C,
@ -1476,7 +1457,15 @@ Return Value:
void
MlasSgemmThreaded(
void* Context,
const ptrdiff_t ThreadCountM,
const ptrdiff_t ThreadCountN,
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const size_t M,
const size_t N,
const size_t K,
const MLAS_SGEMM_DATA_PARAMS* DataParams,
ptrdiff_t ThreadId
)
/*++
@ -1488,7 +1477,17 @@ Routine Description:
Arguments:
Context - Supplies the pointer to the context for the threaded operation.
ThreadCountM - Supplies the total thread partition on the M dimension.
ThreadCountN - Supplies the total thread partition on the N dimension.
TransA - Supplies the transpose operation on A matrix
TransB - Supplies the transpose operation on B matrix
M, N, K - Supplies the shape of the multiplication
DataParams - Supplies the data position and layout of the matrices
ThreadId - Supplies the current index of the threaded operation.
@ -1498,10 +1497,6 @@ Return Value:
--*/
{
const auto* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context;
const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM;
const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN;
const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN;
const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN;
@ -1510,7 +1505,6 @@ Return Value:
// Partition the operation along the M dimension.
//
size_t M = WorkBlock->M;
size_t RangeStartM;
size_t RangeCountM;
@ -1520,7 +1514,6 @@ Return Value:
// Partition the operation along the N dimension.
//
size_t N = WorkBlock->N;
size_t RangeStartN;
size_t RangeCountN;
@ -1539,61 +1532,42 @@ Return Value:
// Dispatch the partitioned operation.
//
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
const size_t lda = DataParams->lda;
const size_t ldc = DataParams->ldc;
const size_t lda = WorkBlock->lda;
const size_t ldc = WorkBlock->ldc;
const float* A = DataParams->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
if (WorkBlock->BIsPacked) {
if (DataParams->BIsPacked) {
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B,
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
K, DataParams->alpha, A, lda, DataParams->B,
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc);
} else {
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
const size_t ldb = DataParams->ldb;
const size_t ldb = WorkBlock->ldb;
const float* B = (const float*)DataParams->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K,
DataParams->alpha, A, lda, B, ldb, DataParams->beta, C, ldc);
}
}
void
MlasSgemmSchedule(
MLAS_SGEMM_WORK_BLOCK* WorkBlock,
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_SGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine schedules the single precision matrix/matrix multiply
operation (SGEMM) across one or more threads.
Arguments:
WorkBlock - Supplies the structure containing the GEMM parameters.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
const size_t M = WorkBlock->M;
const size_t N = WorkBlock->N;
const size_t K = WorkBlock->K;
//
// Compute the number of target threads given the complexity of the SGEMM
@ -1623,206 +1597,41 @@ Return Value:
// works okay for operations involving skinny matrices.
//
ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize;
ptrdiff_t ThreadCountM;
ptrdiff_t ThreadCountN;
if (N > M) {
const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) /
MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
if (size_t(TargetThreadCount) > BlockedN) {
TargetThreadCount = ptrdiff_t(BlockedN);
if (size_t(ThreadsPerGemm) > BlockedN) {
ThreadsPerGemm = ptrdiff_t(BlockedN);
}
WorkBlock->ThreadCountM = 1;
WorkBlock->ThreadCountN = TargetThreadCount;
ThreadCountM = 1;
ThreadCountN = ThreadsPerGemm;
} else {
if (size_t(TargetThreadCount) > M) {
TargetThreadCount = ptrdiff_t(M);
if (size_t(ThreadsPerGemm) > M) {
ThreadsPerGemm = ptrdiff_t(M);
}
WorkBlock->ThreadCountM = TargetThreadCount;
WorkBlock->ThreadCountN = 1;
ThreadCountM = ThreadsPerGemm;
ThreadCountN = 1;
}
MlasExecuteThreaded(MlasSgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
}
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
float alpha,
const float* A,
size_t lda,
const float* B,
size_t ldb,
float beta,
float* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the single precision matrix/matrix multiply
operation (SGEMM).
Arguments:
TransA - Supplies the transpose operation for matrix A.
TransB - Supplies the transpose operation for matrix B.
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
B - Supplies the address of matrix B.
ldb - Supplies the first dimension of matrix B.
beta - Supplies the scalar beta multiplier (see SGEMM definition).
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
MLAS_SGEMM_WORK_BLOCK WorkBlock;
//
// Capture the GEMM parameters to the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK));
WorkBlock.TransA = TransA;
WorkBlock.TransB = TransB;
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = B;
WorkBlock.ldb = ldb;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
//
// Schedule the operation across a set of worker threads.
//
MlasSgemmSchedule(&WorkBlock, ThreadPool);
}
void
MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
size_t M,
size_t N,
size_t K,
float alpha,
const float* A,
size_t lda,
const void* PackedB,
float beta,
float* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine implements the single precision matrix/matrix multiply
operation (SGEMM).
Arguments:
TransA - Supplies the transpose operation for matrix A.
M - Supplies the number of rows of matrix A and matrix C.
N - Supplies the number of columns of matrix B and matrix C.
K - Supplies the number of columns of matrix A and the number of rows of
matrix B.
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
A - Supplies the address of matrix A.
lda - Supplies the first dimension of matrix A.
PackedB - Supplies the address of packed matrix B.
beta - Supplies the scalar beta multiplier (see SGEMM definition).
C - Supplies the address of matrix C.
ldc - Supplies the first dimension of matrix C.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{
MLAS_SGEMM_WORK_BLOCK WorkBlock;
//
// Capture the GEMM parameters to the work block.
//
memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK));
WorkBlock.TransA = TransA;
WorkBlock.M = M;
WorkBlock.N = N;
WorkBlock.K = K;
WorkBlock.A = A;
WorkBlock.lda = lda;
WorkBlock.B = PackedB;
WorkBlock.C = C;
WorkBlock.ldc = ldc;
WorkBlock.alpha = alpha;
WorkBlock.beta = beta;
WorkBlock.BIsPacked = true;
//
// Schedule the operation across a set of worker threads.
//
MlasSgemmSchedule(&WorkBlock, ThreadPool);
MlasTrySimpleParallel(ThreadPool,
ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize),
[=](ptrdiff_t tid)
{
ptrdiff_t GemmIdx = tid / ThreadsPerGemm;
ptrdiff_t ThreadIdx = tid % ThreadsPerGemm;
MlasSgemmThreaded(ThreadCountM, ThreadCountN,
TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx);
});
}
size_t

View file

@ -88,7 +88,7 @@ MlasTrySimpleParallel(
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (ptrdiff_t tid = 0; tid < Iterations; tid++) {
Work(tid);
}

View file

@ -159,38 +159,27 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
const auto* b_data = b ? b->Data<float>() : nullptr;
auto* y_data = y->MutableData<float>();
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
size_t max_len = helper.OutputOffsets().size();
const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
// Leading dimensions of A and B for the chosen transpose modes. They are
// computed directly in size_t so that a large dimension is never narrowed
// through int (which would truncate or sign-wrap values above INT_MAX).
const size_t lda = trans_a ? M : K;
const size_t ldb = trans_b ? K : N;
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
for (size_t i = 0; i < max_len; i++) {
if (packed_b_) {
MlasGemm(
trans_a ? CblasTrans : CblasNoTrans,
static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
alpha_attr_,
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(trans_a ? helper.M() : helper.K()),
packed_b_.get(),
0.0f,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
continue;
}
math::Gemm<float, concurrency::ThreadPool>(
trans_a ? CblasTrans : CblasNoTrans,
trans_b ? CblasTrans : CblasNoTrans,
helper.M(),
helper.N(),
helper.K(),
alpha_attr_,
a_data + helper.LeftOffsets()[i],
b_data + helper.RightOffsets()[i],
0.0f,
y_data + helper.OutputOffsets()[i],
thread_pool);
data[i].BIsPacked = bool(packed_b_);
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i];
data[i].ldb = ldb;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
data[i].alpha = alpha_attr_;
data[i].beta = 0.0f;
}
MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
M, N, K, data.data(), max_len, thread_pool);
return Status::OK();
}

View file

@ -108,6 +108,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool());
//TODO!! consider making this a post processor, so that we can parallelize this loop
MlasRequantizeOutput(gemm_output,
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
nullptr,

View file

@ -17,18 +17,18 @@ void QGEMM(benchmark::State& state, bool pack_b) {
const uint8_t a_zero_point = 29;
const uint8_t b_zero_point = 179;
const int64_t M = state.range(0);
const int64_t N = state.range(1);
const int64_t K = state.range(2);
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
if (state.range(3) <= 0) throw std::invalid_argument("Batch must greater than 0!");
if (state.range(4) <= 0) throw std::invalid_argument("Threads must greater than 0!");
const int64_t batch = state.range(3);
const int64_t threads = state.range(4);
const size_t M = static_cast<size_t>(state.range(0));
const size_t N = static_cast<size_t>(state.range(1));
const size_t K = static_cast<size_t>(state.range(2));
if (M <= 0) throw std::invalid_argument("M must greater than 0!");
if (N <= 0) throw std::invalid_argument("N must greater than 0!");
if (K <= 0) throw std::invalid_argument("K must greater than 0!");
if (batch <= 0) throw std::invalid_argument("Batch must greater than 0!");
if (threads <= 0) throw std::invalid_argument("Threads must greater than 0!");
const size_t batch = static_cast<size_t>(state.range(3));
const size_t threads = static_cast<size_t>(state.range(4));
OrtThreadPoolParams tpo;
tpo.thread_pool_size = int(threads);

View file

@ -10,13 +10,13 @@
static const std::vector<std::string> sgemm_bench_arg_names = {"M", "N", "K"};
void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) {
const int64_t M = state.range(0);
const int64_t N = state.range(1);
const int64_t K = state.range(2);
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
const size_t M = static_cast<size_t>(state.range(0));
const size_t N = static_cast<size_t>(state.range(1));
const size_t K = static_cast<size_t>(state.range(2));
if (M <= 0) throw std::invalid_argument("M must greater than 0!");
if (N <= 0) throw std::invalid_argument("N must greater than 0!");
if (K <= 0) throw std::invalid_argument("K must greater than 0!");
auto A = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
auto B = RandomVectorUniform(static_cast<size_t>(N * K), -1.0f, 1.0f);

View file

@ -7,7 +7,6 @@
#include <memory>
#include <sstream>
template <> MlasFgemmTest<float, false, false>* MlasTestFixture<MlasFgemmTest<float, false, false>>::mlas_tester(nullptr);
template <> MlasFgemmTest<float, false, true>* MlasTestFixture<MlasFgemmTest<float, false, true>>::mlas_tester(nullptr);
template <> MlasFgemmTest<float, true, false>* MlasTestFixture<MlasFgemmTest<float, true, false>>::mlas_tester(nullptr);

View file

@ -21,8 +21,8 @@ const char* GetGemmTestSuitePrefix<double>() {
template <typename T, bool Packed>
class FgemmPackedContext;
template <typename T>
class FgemmPackedContext<T, false> {
template <>
class FgemmPackedContext<float, false> {
public:
void
TestGemm(
@ -31,21 +31,69 @@ class FgemmPackedContext<T, false> {
size_t M,
size_t N,
size_t K,
float alpha,
const T* A,
size_t BatchSize,
const float alpha,
const float* A,
size_t lda,
const T* B,
const float* B,
size_t ldb,
float beta,
T* C,
const float beta,
float* C,
size_t ldc,
MLAS_THREADPOOL* threadpool) {
MlasGemm(TransA, TransB, M, N, K, T(alpha), A, lda, B, ldb, T(beta), C, ldc, threadpool);
std::vector<MLAS_SGEMM_DATA_PARAMS> data(BatchSize);
for (size_t i = 0; i < BatchSize; i++) {
data[i].A = A + M * K * i;
data[i].lda = lda;
data[i].B = B + K * N * i;
data[i].ldb = ldb;
data[i].C = C + M * N * i;
data[i].ldc = ldc;
data[i].alpha = alpha;
data[i].beta = beta;
}
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
}
};
template <typename T>
class FgemmPackedContext<T, true> {
#ifdef MLAS_TARGET_AMD64
template <>
class FgemmPackedContext<double, false> {
public:
void TestGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
size_t BatchSize,
double alpha,
const double* A,
size_t lda,
const double* B,
size_t ldb,
double beta,
double* C,
size_t ldc,
MLAS_THREADPOOL* threadpool) {
std::vector<MLAS_DGEMM_DATA_PARAMS> data(BatchSize);
for (size_t i = 0; i < BatchSize; i++) {
data[i].A = A + M * K * i;
data[i].lda = lda;
data[i].B = B + K * N * i;
data[i].ldb = ldb;
data[i].C = C + M * N * i;
data[i].ldc = ldc;
data[i].alpha = alpha;
data[i].beta = beta;
}
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
}
};
#endif
template <>
class FgemmPackedContext<float, true> {
public:
void
TestGemm(
@ -54,19 +102,34 @@ class FgemmPackedContext<T, true> {
size_t M,
size_t N,
size_t K,
float alpha,
const T* A,
size_t BatchSize,
const float alpha,
const float* A,
size_t lda,
const T* B,
const float* B,
size_t ldb,
float beta,
T* C,
const float beta,
float* C,
size_t ldc,
MLAS_THREADPOOL* threadpool) {
size_t PackedBSize = MlasGemmPackBSize(N, K);
void* PackedB = BufferBPacked.GetBuffer(PackedBSize, true);
MlasGemmPackB(TransB, N, K, B, ldb, PackedB);
MlasGemm(TransA, M, N, K, T(alpha), A, lda, PackedB, T(beta), C, ldc, threadpool);
// Pack each batch element's B into its own PackedBSize-byte slice of one
// shared buffer, then run the whole batch with a single MlasGemmBatch call.
void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true);
std::vector<MLAS_SGEMM_DATA_PARAMS> data(BatchSize);
for (size_t i = 0; i < BatchSize; i++) {
  uint8_t* PackedBSlice = (uint8_t*)PackedB + PackedBSize * i;
  MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, PackedBSlice);
  data[i].BIsPacked = true;
  data[i].A = A + M * K * i;
  data[i].lda = lda;
  data[i].B = (float*)PackedBSlice;
  data[i].ldb = ldb;
  data[i].C = C + M * N * i;
  data[i].ldc = ldc;
  data[i].alpha = alpha;
  data[i].beta = beta;
}
// The batch call is the only GEMM that may touch C here: running another GEMM
// on the same C would scale it by beta a second time and no longer match the
// reference result, which applies alpha/beta exactly once per element.
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
}
private:
@ -89,14 +152,14 @@ class MlasFgemmTest : public MlasTestBase {
MlasFgemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) { }
void Test(size_t M, size_t N, size_t K, float alpha, float beta) {
Test(false, false, M, N, K, alpha, beta);
Test(false, true, M, N, K, alpha, beta);
Test(true, false, M, N, K, alpha, beta);
Test(true, true, M, N, K, alpha, beta);
// Runs the bool-flag Test overload for all four transpose combinations of A
// and B, in the order (A,B), (A,Bt), (At,B), (At,Bt).
void Test(size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) {
  for (int ta = 0; ta < 2; ta++) {
    for (int tb = 0; tb < 2; tb++) {
      Test(ta != 0, tb != 0, M, N, K, BatchSize, alpha, beta);
    }
  }
}
void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) {
void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) {
//
// Skip the test if the B buffer cannot be packed.
//
@ -105,14 +168,14 @@ class MlasFgemmTest : public MlasTestBase {
return;
}
const T* A = BufferA.GetBuffer(K * M);
const T* B = BufferB.GetBuffer(N * K);
T* C = BufferC.GetBuffer(N * M);
T* CReference = BufferCReference.GetBuffer(N * M);
const T* A = BufferA.GetBuffer(K * M * BatchSize);
const T* B = BufferB.GetBuffer(N * K * BatchSize);
T* C = BufferC.GetBuffer(N * M * BatchSize);
T* CReference = BufferCReference.GetBuffer(N * M * BatchSize);
Test(trans_a ? CblasTrans : CblasNoTrans,
trans_b ? CblasTrans : CblasNoTrans,
M, N, K, alpha, A, trans_a ? M : K, B, trans_b ? K : N,
M, N, K, BatchSize, alpha, A, trans_a ? M : K, B, trans_b ? K : N,
beta, C, CReference, N);
}
@ -121,33 +184,36 @@ class MlasFgemmTest : public MlasTestBase {
size_t M,
size_t N,
size_t K,
float alpha,
size_t BatchSize,
T alpha,
const T* A,
size_t lda,
const T* B,
size_t ldb,
float beta,
T beta,
T* C,
T* CReference,
size_t ldc) {
std::fill_n(C, M * N, -0.5f);
std::fill_n(CReference, M * N, -0.5f);
std::fill_n(C, M * N * BatchSize, -0.5f);
std::fill_n(CReference, M * N * BatchSize, -0.5f);
PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_);
ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc);
PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_);
ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc);
for (size_t m = 0, f = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
// Sensitive to comparing positive/negative zero.
ASSERT_EQ(C[f], CReference[f])
<< " Diff @[" << m << ", " << n << "] f=" << f << ", "
<< (Packed ? "Packed" : "NoPack") << "."
<< (Threaded ? "SingleThread" : "Threaded") << "/"
<< (TransA == CblasTrans ? "TransA" : "A") << "/"
<< (TransB == CblasTrans ? "TransB" : "B") << "/"
<< "M" << M << "xN" << N << "xK" << K << "/"
<< "Alpha" << alpha << "/"
<< "Beta" << beta;
// Compare the tested output with the reference, element by element.
// f walks C and CReference linearly while (batch, m, n) name the position
// for the failure message.
for (size_t batch = 0, f = 0; batch < BatchSize; batch++) {
  for (size_t m = 0; m < M; m++) {
    for (size_t n = 0; n < N; n++, f++) {
      // Sensitive to comparing positive/negative zero.
      ASSERT_EQ(C[f], CReference[f])
          << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", "
          << (Packed ? "Packed" : "NoPack") << "."
          // Threaded == true means the test runs with the MLAS thread pool.
          << (Threaded ? "Threaded" : "SingleThread") << "/"
          << (TransA == CblasTrans ? "TransA" : "A") << "/"
          << (TransB == CblasTrans ? "TransB" : "B") << "/"
          << "M" << M << "xN" << N << "xK" << K << "/"
          << "Alpha" << alpha << "/"
          << "Beta" << beta;
    }
  }
}
}
@ -157,111 +223,120 @@ class MlasFgemmTest : public MlasTestBase {
size_t M,
size_t N,
size_t K,
float alpha,
size_t BatchSize,
T alpha,
const T* A,
size_t lda,
const T* B,
size_t ldb,
float beta,
T beta,
T* C,
size_t ldc) {
if (TransA == CblasNoTrans) {
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t batch = 0; batch < BatchSize; batch++) {
if (TransA == CblasNoTrans) {
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += 1;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += 1;
}
*c = (*c * beta) + (sum * alpha);
}
}
*c = (*c * beta) + (sum * alpha);
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += 1;
}
*c = (*c * beta) + (sum * alpha);
}
}
}
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += 1;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
}
}
*c = (*c * beta) + (sum * alpha);
}
}
}
} else {
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
}
}
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
}
}
}
}
A += M * K;
B += K * N;
C += M * N;
}
}
void ExecuteLong() override {
static const float multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f};
static const T multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f};
for (size_t N = 1; N < 128; N++) {
for (size_t K = 1; K < 128; K++) {
for (size_t a = 0; a < _countof(multipliers); a++) {
for (size_t b = 0; b < _countof(multipliers); b++) {
Test(1, N, K, multipliers[a], multipliers[b]);
Test(N, 1, K, multipliers[a], multipliers[b]);
Test(1, N, K, 1, multipliers[a], multipliers[b]);
Test(N, 1, K, 1, multipliers[a], multipliers[b]);
if (!Packed) {
Test(1, N, K, 3, multipliers[a], multipliers[b]);
}
}
}
}
}
for (size_t a = 0; a < _countof(multipliers); a++) {
float alpha = multipliers[a];
T alpha = multipliers[a];
for (size_t b = 0; b < _countof(multipliers); b++) {
float beta = multipliers[b];
T beta = multipliers[b];
for (size_t M = 16; M < 160; M += 32) {
for (size_t N = 16; N < 160; N += 32) {
@ -269,21 +344,25 @@ class MlasFgemmTest : public MlasTestBase {
for (size_t k = 0; k < _countof(ks); k++) {
size_t K = ks[k];
Test(M, N, K, alpha, beta);
Test(M + 1, N, K, alpha, beta);
Test(M, N + 1, K, alpha, beta);
Test(M + 1, N + 1, K, alpha, beta);
Test(M + 3, N + 2, K, alpha, beta);
Test(M + 4, N, K, alpha, beta);
Test(M, N + 4, K, alpha, beta);
Test(M + 4, N + 4, K, alpha, beta);
Test(M + 3, N + 7, K, alpha, beta);
Test(M + 8, N, K, alpha, beta);
Test(M, N + 8, K, alpha, beta);
Test(M + 12, N + 12, K, alpha, beta);
Test(M + 13, N, K, alpha, beta);
Test(M, N + 15, K, alpha, beta);
Test(M + 15, N + 15, K, alpha, beta);
Test(M, N, K, 1, alpha, beta);
Test(M + 1, N, K, 1, alpha, beta);
Test(M, N + 1, K, 1, alpha, beta);
Test(M + 1, N + 1, K, 1, alpha, beta);
Test(M + 3, N + 2, K, 1, alpha, beta);
Test(M + 4, N, K, 1, alpha, beta);
Test(M, N + 4, K, 1, alpha, beta);
Test(M + 4, N + 4, K, 1, alpha, beta);
Test(M + 3, N + 7, K, 1, alpha, beta);
Test(M + 8, N, K, 1, alpha, beta);
Test(M, N + 8, K, 1, alpha, beta);
Test(M + 12, N + 12, K, 1, alpha, beta);
Test(M + 13, N, K, 1, alpha, beta);
Test(M, N + 15, K, 1, alpha, beta);
Test(M + 15, N + 15, K, 1, alpha, beta);
if (!Packed) {
Test(M + 3, N + 1, K, 7, multipliers[a], multipliers[b]);
Test(M + 13, N + 2, K, 9, multipliers[a], multipliers[b]);
}
}
}
printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M);
@ -294,7 +373,7 @@ class MlasFgemmTest : public MlasTestBase {
for (size_t M = 0; M < 160; M++) {
for (size_t N = 0; N < 160; N++) {
for (size_t K = 0; K < 160; K++) {
Test(M, N, K, 1.0f, 0.0f);
Test(M, N, K, 1, 1.0f, 0.0f);
}
}
printf("M %zd\n", M);
@ -303,10 +382,10 @@ class MlasFgemmTest : public MlasTestBase {
for (size_t M = 160; M < 320; M += 24) {
for (size_t N = 112; N < 320; N += 24) {
for (size_t K = 0; K < 16; K++) {
Test(M, N, K, 1.0f, 0.0f);
Test(M, N, K, 1, 1.0f, 0.0f);
}
for (size_t K = 16; K < 160; K += 32) {
Test(M, N, K, 1.0f, 0.0f);
Test(M, N, K, 1, 1.0f, 0.0f);
}
}
printf("M %zd\n", M);

View file

@ -14,20 +14,20 @@
template <typename T, bool Packed, bool Threaded>
class FgemmShortExecuteTest : public MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>> {
public:
explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta)
: trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), alpha_(alpha), beta_(beta) {
// Stores the transpose flags, problem dimensions, batch count and scaling
// factors of one test case; TestBody forwards them to the shared tester.
explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta)
    : trans_a_(trans_a),
      trans_b_(trans_b),
      M_(M),
      N_(N),
      K_(K),
      Batch_(BatchSize),
      alpha_(alpha),
      beta_(beta) {
}
void TestBody() override {
MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>>::mlas_tester->Test(
trans_a_, trans_b_, M_, N_, K_, alpha_, beta_);
trans_a_, trans_b_, M_, N_, K_, Batch_, alpha_, beta_);
}
static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) {
static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) {
std::stringstream ss;
ss << (trans_a ? "TransA" : "A") << "/"
<< (trans_b ? "TransB" : "B") << "/"
<< "M" << M << "xN" << N << "xK" << K << "/"
<< "BatchSize" << BatchSize << "/M" << M << "xN" << N << "xK" << K << "/"
<< "Alpha" << alpha << "/"
<< "Beta" << beta;
auto test_name = ss.str();
@ -42,37 +42,39 @@ class FgemmShortExecuteTest : public MlasTestFixture<MlasFgemmTest<T, Packed, Th
// Important to use the fixture type as the return type here.
[=]() -> MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>>* {
return new FgemmShortExecuteTest<T, Packed, Threaded>(
trans_a, trans_b, M, N, K, alpha, beta);
trans_a, trans_b, M, N, K, BatchSize, alpha, beta);
});
return 1;
}
static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, float alpha, float beta) {
return RegisterSingleTest(false, false, M, N, K, alpha, beta) +
RegisterSingleTest(false, true, M, N, K, alpha, beta) +
RegisterSingleTest(true, false, M, N, K, alpha, beta) +
RegisterSingleTest(true, true, M, N, K, alpha, beta);
// Registers one test per transpose combination of A and B for the given
// shape, batch size and scaling factors; returns how many were registered.
static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) {
  size_t registered = 0;
  registered += RegisterSingleTest(false, false, M, N, K, BatchSize, alpha, beta);
  registered += RegisterSingleTest(false, true, M, N, K, BatchSize, alpha, beta);
  registered += RegisterSingleTest(true, false, M, N, K, BatchSize, alpha, beta);
  registered += RegisterSingleTest(true, true, M, N, K, BatchSize, alpha, beta);
  return registered;
}
static size_t RegisterShortExecuteTests() {
size_t test_registered = 0;
for (size_t b = 0; b < 16; b++) {
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(b, b, b, 3, 1.0f, 0.0f);
}
for (size_t b = 16; b <= 256; b <<= 1) {
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
}
for (size_t b = 256; b < 320; b += 32) {
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
}
test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1, 1.0f, 0.0f);
test_registered += RegisterTestTransposeABProduct(25, 81, 79, 7, 1.0f, 0.0f);
return test_registered;
}
private:
bool trans_a_, trans_b_;
size_t M_, N_, K_;
float alpha_, beta_;
const size_t M_, N_, K_, Batch_;
const T alpha_, beta_;
};