mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Add batch interface to floating point GEMM (#7323)
Currently in high dimension matmul, we call multiple GEMM sequentially. In this change we execute these GEMMs in parallel, removing barriers between two adjacent GEMM operations. Performance tested with Bert and T5 model. Bert model shows no noticeable perf differences, as the heavy lifting is done by the attention operator, which is not changed in this PR. In T5 model, we see no regression on low parallel threads (x4), and performance improvement is more pronounced in high number of threads (8-16). T5 shows 10% speedup with 16 threads. With profiling, we can see the most expensive MatMul operators in T5 achieves around 20% speedup with 16 threads. Co-authored-by: Chen Fu <fuchen@microsoft.com>
This commit is contained in:
parent
7a3c1787af
commit
f4f2cc1a00
12 changed files with 624 additions and 623 deletions
|
|
@ -309,6 +309,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
T* qkv_dest = QKV[qkv_index];
|
||||
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
|
||||
|
||||
// TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process?
|
||||
// broadcast 3NH -> (3.B.N.S.H)
|
||||
const T* broadcast_data_src = bias_data + weights_offset;
|
||||
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;
|
||||
|
|
|
|||
|
|
@ -147,10 +147,102 @@ MlasActivation(
|
|||
|
||||
//
|
||||
// Matrix/matrix multiply routines.
|
||||
// C := alpha * op(A) * op(B) + beta * C
|
||||
// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X))
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief Supply matrices data information to single precision gemm functions
|
||||
*/
|
||||
struct MLAS_SGEMM_DATA_PARAMS {
|
||||
const float* A = nullptr; /**< Supplies the address of matrix A */
|
||||
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
|
||||
const float* B = nullptr; /**< Supplies the address of matrix B */
|
||||
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
|
||||
float* C = nullptr; /**< Supplies the address of matrix C */
|
||||
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
|
||||
float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */
|
||||
float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */
|
||||
bool BIsPacked = false; /**< Whether B is pre-packed */
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Batched single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param Data A array of matrices data parameters
|
||||
* @param BatchSize Supplies number of multiplications in this batch
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemmBatch(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_SGEMM_DATA_PARAMS* Data,
|
||||
size_t BatchSize,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param Data Supplies the matrices data parameters
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
inline
|
||||
void
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_SGEMM_DATA_PARAMS& Data,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param lda Supplies the first dimension of matrix A.
|
||||
* @param B Supplies the address of matrix B
|
||||
* @param ldb Supplies the first dimension of matrix B.
|
||||
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
|
||||
* @param C Supplies the address of matrix C
|
||||
* @param ldc Supplies the first dimension of matrix C.
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
inline
|
||||
void
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
|
|
@ -166,10 +258,41 @@ MlasGemm(
|
|||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
)
|
||||
{
|
||||
MLAS_SGEMM_DATA_PARAMS Data;
|
||||
Data.alpha = alpha;
|
||||
Data.A = A;
|
||||
Data.lda = lda;
|
||||
Data.B = B;
|
||||
Data.ldb = ldb;
|
||||
Data.beta = beta;
|
||||
Data.C = C;
|
||||
Data.ldc = ldc;
|
||||
|
||||
MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief the single precision matrix/matrix multiply operation (SGEMM) with pre-packed B
|
||||
*
|
||||
* @param TransA - Supplies the transpose operation for matrix A.
|
||||
* @param M - Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N - Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K - Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
|
||||
* @param A - Supplies the address of matrix A.
|
||||
* @param lda - Supplies the first dimension of matrix A.
|
||||
* @param PackedB - Supplies the address of packed matrix B.
|
||||
* @param beta - Supplies the scalar beta multiplier (see SGEMM definition).
|
||||
* @param C - Supplies the address of matrix C.
|
||||
* @param ldc - Supplies the first dimension of matrix C.
|
||||
* @param ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
inline
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
size_t M,
|
||||
|
|
@ -183,10 +306,117 @@ MlasGemm(
|
|||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
)
|
||||
{
|
||||
MLAS_SGEMM_DATA_PARAMS DataParams;
|
||||
DataParams.A = A;
|
||||
DataParams.lda = lda;
|
||||
DataParams.B = static_cast<const float*>(PackedB);
|
||||
DataParams.ldb = 0;
|
||||
DataParams.C = C;
|
||||
DataParams.ldc = ldc;
|
||||
DataParams.alpha = alpha;
|
||||
DataParams.beta = beta;
|
||||
DataParams.BIsPacked = true;
|
||||
|
||||
MlasGemmBatch(TransA,
|
||||
CblasTrans, // deos not matter when B is packed
|
||||
M, N, K, &DataParams, 1, ThreadPool);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Supply matrices data information to double precision gemm functions
|
||||
*/
|
||||
struct MLAS_DGEMM_DATA_PARAMS {
|
||||
const double* A = nullptr; /**< Supplies the address of matrix A */
|
||||
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
|
||||
const double* B = nullptr; /**< Supplies the address of matrix B */
|
||||
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
|
||||
double* C = nullptr; /**< Supplies the address of matrix C */
|
||||
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
|
||||
double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */
|
||||
double beta = 0.0; /**< Supplies the scalar beta multiplier (see SGEMM definition) */
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Batched double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param Data A array of matrices data parameters
|
||||
* @param BatchSize Supplies number of multiplications in this batch
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemmBatch(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_DGEMM_DATA_PARAMS* Data,
|
||||
size_t BatchSize,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param Data Supplies the matrices data parameters
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
inline
|
||||
void
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_DGEMM_DATA_PARAMS& Data,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
* @param N Supplies the number of columns of matrix B and matrix C.
|
||||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param lda Supplies the first dimension of matrix A.
|
||||
* @param B Supplies the address of matrix B
|
||||
* @param ldb Supplies the first dimension of matrix B.
|
||||
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
|
||||
* @param C Supplies the address of matrix C
|
||||
* @param ldc Supplies the first dimension of matrix C.
|
||||
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
*/
|
||||
inline
|
||||
void
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
|
|
@ -202,7 +432,19 @@ MlasGemm(
|
|||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
)
|
||||
{
|
||||
MLAS_DGEMM_DATA_PARAMS Data;
|
||||
Data.alpha = alpha;
|
||||
Data.A = A;
|
||||
Data.lda = lda;
|
||||
Data.B = B;
|
||||
Data.ldb = ldb;
|
||||
Data.beta = beta;
|
||||
Data.C = C;
|
||||
Data.ldc = ldc;
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
|
||||
}
|
||||
|
||||
enum class MLAS_QUANTIZATION_GRANULARITY {
|
||||
PerMatrix,
|
||||
|
|
|
|||
|
|
@ -26,29 +26,6 @@ Abstract:
|
|||
|
||||
#define MLAS_DGEMM_TRANSA_ROWS 12
|
||||
|
||||
//
|
||||
// Define the parameters to execute segments of a DGEMM operation on worker
|
||||
// threads.
|
||||
//
|
||||
|
||||
struct MLAS_DGEMM_WORK_BLOCK {
|
||||
ptrdiff_t ThreadCountM;
|
||||
ptrdiff_t ThreadCountN;
|
||||
CBLAS_TRANSPOSE TransA;
|
||||
CBLAS_TRANSPOSE TransB;
|
||||
size_t M;
|
||||
size_t N;
|
||||
size_t K;
|
||||
const double* A;
|
||||
size_t lda;
|
||||
const double* B;
|
||||
size_t ldb;
|
||||
double* C;
|
||||
size_t ldc;
|
||||
double alpha;
|
||||
double beta;
|
||||
};
|
||||
|
||||
#ifdef MLAS_TARGET_AMD64
|
||||
|
||||
void
|
||||
|
|
@ -750,8 +727,15 @@ Return Value:
|
|||
|
||||
void
|
||||
MlasDgemmThreaded(
|
||||
void* Context,
|
||||
ptrdiff_t ThreadId
|
||||
const ptrdiff_t ThreadCountM,
|
||||
const ptrdiff_t ThreadCountN,
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const size_t M,
|
||||
const size_t N,
|
||||
const size_t K,
|
||||
const MLAS_DGEMM_DATA_PARAMS* Data,
|
||||
const ptrdiff_t ThreadId
|
||||
)
|
||||
/*++
|
||||
|
||||
|
|
@ -772,10 +756,6 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;
|
||||
|
||||
const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM;
|
||||
const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN;
|
||||
|
||||
const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN;
|
||||
const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN;
|
||||
|
|
@ -784,7 +764,6 @@ Return Value:
|
|||
// Partition the operation along the M dimension.
|
||||
//
|
||||
|
||||
size_t M = WorkBlock->M;
|
||||
size_t RangeStartM;
|
||||
size_t RangeCountM;
|
||||
|
||||
|
|
@ -794,7 +773,6 @@ Return Value:
|
|||
// Partition the operation along the N dimension.
|
||||
//
|
||||
|
||||
size_t N = WorkBlock->N;
|
||||
size_t RangeStartN;
|
||||
size_t RangeCountN;
|
||||
|
||||
|
|
@ -813,50 +791,32 @@ Return Value:
|
|||
// Dispatch the partitioned operation.
|
||||
//
|
||||
|
||||
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
|
||||
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
|
||||
const size_t lda = Data->lda;
|
||||
const size_t ldb = Data->ldb;
|
||||
const size_t ldc = Data->ldc;
|
||||
|
||||
const size_t lda = WorkBlock->lda;
|
||||
const size_t ldb = WorkBlock->ldb;
|
||||
const size_t ldc = WorkBlock->ldc;
|
||||
const double* A = Data->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
const double* B = Data->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
double* C = Data->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
|
||||
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
|
||||
MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K,
|
||||
Data->alpha, A, lda, B, ldb, Data->beta, C, ldc);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
MlasDgemmSchedule(
|
||||
MLAS_DGEMM_WORK_BLOCK* WorkBlock,
|
||||
MLASCALL
|
||||
MlasGemmBatch(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_DGEMM_DATA_PARAMS* Data,
|
||||
size_t BatchSize,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine schedules the double precision matrix/matrix multiply
|
||||
operation (DGEMM) across one or more threads.
|
||||
|
||||
Arguments:
|
||||
|
||||
WorkBlock - Supplies the structure containing the GEMM parameters.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t M = WorkBlock->M;
|
||||
const size_t N = WorkBlock->N;
|
||||
const size_t K = WorkBlock->K;
|
||||
|
||||
//
|
||||
// Compute the number of target threads given the complexity of the DGEMM
|
||||
// operation. Small requests should run using the single threaded path.
|
||||
|
|
@ -885,121 +845,40 @@ Return Value:
|
|||
// works okay for operations involving skinny matrices.
|
||||
//
|
||||
|
||||
ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize;
|
||||
ptrdiff_t ThreadCountM;
|
||||
ptrdiff_t ThreadCountN;
|
||||
|
||||
if (N > M) {
|
||||
|
||||
const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
||||
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;
|
||||
|
||||
if (size_t(TargetThreadCount) > BlockedN) {
|
||||
TargetThreadCount = ptrdiff_t(BlockedN);
|
||||
if (size_t(ThreadsPerGemm) > BlockedN) {
|
||||
ThreadsPerGemm = ptrdiff_t(BlockedN);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = 1;
|
||||
WorkBlock->ThreadCountN = TargetThreadCount;
|
||||
ThreadCountM = 1;
|
||||
ThreadCountN = ThreadsPerGemm;
|
||||
|
||||
} else {
|
||||
|
||||
if (size_t(TargetThreadCount) > M) {
|
||||
TargetThreadCount = ptrdiff_t(M);
|
||||
if (size_t(ThreadsPerGemm) > M) {
|
||||
ThreadsPerGemm = ptrdiff_t(M);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = TargetThreadCount;
|
||||
WorkBlock->ThreadCountN = 1;
|
||||
ThreadCountM = ThreadsPerGemm;
|
||||
ThreadCountN = 1;
|
||||
}
|
||||
|
||||
MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
|
||||
}
|
||||
const ptrdiff_t TotalThreads = ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize);
|
||||
MlasTrySimpleParallel(ThreadPool, TotalThreads, [=](ptrdiff_t tid) {
|
||||
const ptrdiff_t GemmIdx = tid / ThreadsPerGemm;
|
||||
const ptrdiff_t ThreadIdx = tid % ThreadsPerGemm;
|
||||
MlasDgemmThreaded(ThreadCountM, ThreadCountN, TransA, TransB,
|
||||
M, N, K, &(Data[GemmIdx]), ThreadIdx);
|
||||
});
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
double alpha,
|
||||
const double* A,
|
||||
size_t lda,
|
||||
const double* B,
|
||||
size_t ldb,
|
||||
double beta,
|
||||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the double precision matrix/matrix multiply
|
||||
operation (DGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
TransA - Supplies the transpose operation for matrix A.
|
||||
|
||||
TransB - Supplies the transpose operation for matrix B.
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
alpha - Supplies the scalar alpha multiplier (see DGEMM definition).
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B.
|
||||
|
||||
ldb - Supplies the first dimension of matrix B.
|
||||
|
||||
beta - Supplies the scalar beta multiplier (see DGEMM definition).
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_DGEMM_WORK_BLOCK WorkBlock;
|
||||
|
||||
//
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK));
|
||||
|
||||
WorkBlock.TransA = TransA;
|
||||
WorkBlock.TransB = TransB;
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = B;
|
||||
WorkBlock.ldb = ldb;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
MlasDgemmSchedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -31,25 +31,6 @@ Abstract:
|
|||
// threads.
|
||||
//
|
||||
|
||||
struct MLAS_SGEMM_WORK_BLOCK {
|
||||
ptrdiff_t ThreadCountM;
|
||||
ptrdiff_t ThreadCountN;
|
||||
CBLAS_TRANSPOSE TransA;
|
||||
CBLAS_TRANSPOSE TransB;
|
||||
size_t M;
|
||||
size_t N;
|
||||
size_t K;
|
||||
const float* A;
|
||||
size_t lda;
|
||||
const void* B;
|
||||
size_t ldb;
|
||||
float* C;
|
||||
size_t ldc;
|
||||
float alpha;
|
||||
float beta;
|
||||
bool BIsPacked;
|
||||
};
|
||||
|
||||
void
|
||||
MlasSgemmMultiplyBeta(
|
||||
float* C,
|
||||
|
|
@ -1476,7 +1457,15 @@ Return Value:
|
|||
|
||||
void
|
||||
MlasSgemmThreaded(
|
||||
void* Context,
|
||||
const ptrdiff_t ThreadCountM,
|
||||
const ptrdiff_t ThreadCountN,
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const size_t M,
|
||||
const size_t N,
|
||||
const size_t K,
|
||||
|
||||
const MLAS_SGEMM_DATA_PARAMS* DataParams,
|
||||
ptrdiff_t ThreadId
|
||||
)
|
||||
/*++
|
||||
|
|
@ -1488,7 +1477,17 @@ Routine Description:
|
|||
|
||||
Arguments:
|
||||
|
||||
Context - Supplies the pointer to the context for the threaded operation.
|
||||
ThreadCountM - Supplies the total thread partition on the M dimension.
|
||||
|
||||
ThreadCountN - Supplies the total thread partition on the N dimension.
|
||||
|
||||
TransA - Supplies the transpose operation on A matrix
|
||||
|
||||
TransB - Supplies the transpose operation on B matrix
|
||||
|
||||
M, N, K - Supplies the shape of the multiplication
|
||||
|
||||
DataParams - Supplies the data position and layout of the matrices
|
||||
|
||||
ThreadId - Supplies the current index of the threaded operation.
|
||||
|
||||
|
|
@ -1498,10 +1497,6 @@ Return Value:
|
|||
|
||||
--*/
|
||||
{
|
||||
const auto* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context;
|
||||
|
||||
const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM;
|
||||
const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN;
|
||||
|
||||
const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN;
|
||||
const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN;
|
||||
|
|
@ -1510,7 +1505,6 @@ Return Value:
|
|||
// Partition the operation along the M dimension.
|
||||
//
|
||||
|
||||
size_t M = WorkBlock->M;
|
||||
size_t RangeStartM;
|
||||
size_t RangeCountM;
|
||||
|
||||
|
|
@ -1520,7 +1514,6 @@ Return Value:
|
|||
// Partition the operation along the N dimension.
|
||||
//
|
||||
|
||||
size_t N = WorkBlock->N;
|
||||
size_t RangeStartN;
|
||||
size_t RangeCountN;
|
||||
|
||||
|
|
@ -1539,61 +1532,42 @@ Return Value:
|
|||
// Dispatch the partitioned operation.
|
||||
//
|
||||
|
||||
CBLAS_TRANSPOSE TransA = WorkBlock->TransA;
|
||||
const size_t lda = DataParams->lda;
|
||||
const size_t ldc = DataParams->ldc;
|
||||
|
||||
const size_t lda = WorkBlock->lda;
|
||||
const size_t ldc = WorkBlock->ldc;
|
||||
const float* A = DataParams->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1);
|
||||
float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
if (WorkBlock->BIsPacked) {
|
||||
if (DataParams->BIsPacked) {
|
||||
|
||||
MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN,
|
||||
WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B,
|
||||
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc);
|
||||
K, DataParams->alpha, A, lda, DataParams->B,
|
||||
BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc);
|
||||
|
||||
} else {
|
||||
|
||||
CBLAS_TRANSPOSE TransB = WorkBlock->TransB;
|
||||
const size_t ldb = DataParams->ldb;
|
||||
|
||||
const size_t ldb = WorkBlock->ldb;
|
||||
const float* B = (const float*)DataParams->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
|
||||
const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb);
|
||||
|
||||
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K,
|
||||
WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc);
|
||||
MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K,
|
||||
DataParams->alpha, A, lda, B, ldb, DataParams->beta, C, ldc);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MlasSgemmSchedule(
|
||||
MLAS_SGEMM_WORK_BLOCK* WorkBlock,
|
||||
MLASCALL
|
||||
MlasGemmBatch(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const MLAS_SGEMM_DATA_PARAMS* Data,
|
||||
size_t BatchSize,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine schedules the single precision matrix/matrix multiply
|
||||
operation (SGEMM) across one or more threads.
|
||||
|
||||
Arguments:
|
||||
|
||||
WorkBlock - Supplies the structure containing the GEMM parameters.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t M = WorkBlock->M;
|
||||
const size_t N = WorkBlock->N;
|
||||
const size_t K = WorkBlock->K;
|
||||
|
||||
//
|
||||
// Compute the number of target threads given the complexity of the SGEMM
|
||||
|
|
@ -1623,206 +1597,41 @@ Return Value:
|
|||
// works okay for operations involving skinny matrices.
|
||||
//
|
||||
|
||||
ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize;
|
||||
ptrdiff_t ThreadCountM;
|
||||
ptrdiff_t ThreadCountN;
|
||||
|
||||
if (N > M) {
|
||||
|
||||
const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) /
|
||||
MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
|
||||
|
||||
if (size_t(TargetThreadCount) > BlockedN) {
|
||||
TargetThreadCount = ptrdiff_t(BlockedN);
|
||||
if (size_t(ThreadsPerGemm) > BlockedN) {
|
||||
ThreadsPerGemm = ptrdiff_t(BlockedN);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = 1;
|
||||
WorkBlock->ThreadCountN = TargetThreadCount;
|
||||
ThreadCountM = 1;
|
||||
ThreadCountN = ThreadsPerGemm;
|
||||
|
||||
} else {
|
||||
|
||||
if (size_t(TargetThreadCount) > M) {
|
||||
TargetThreadCount = ptrdiff_t(M);
|
||||
if (size_t(ThreadsPerGemm) > M) {
|
||||
ThreadsPerGemm = ptrdiff_t(M);
|
||||
}
|
||||
|
||||
WorkBlock->ThreadCountM = TargetThreadCount;
|
||||
WorkBlock->ThreadCountN = 1;
|
||||
ThreadCountM = ThreadsPerGemm;
|
||||
ThreadCountN = 1;
|
||||
}
|
||||
|
||||
MlasExecuteThreaded(MlasSgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
const float* A,
|
||||
size_t lda,
|
||||
const float* B,
|
||||
size_t ldb,
|
||||
float beta,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the single precision matrix/matrix multiply
|
||||
operation (SGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
TransA - Supplies the transpose operation for matrix A.
|
||||
|
||||
TransB - Supplies the transpose operation for matrix B.
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
B - Supplies the address of matrix B.
|
||||
|
||||
ldb - Supplies the first dimension of matrix B.
|
||||
|
||||
beta - Supplies the scalar beta multiplier (see SGEMM definition).
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_SGEMM_WORK_BLOCK WorkBlock;
|
||||
|
||||
//
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK));
|
||||
|
||||
WorkBlock.TransA = TransA;
|
||||
WorkBlock.TransB = TransB;
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = B;
|
||||
WorkBlock.ldb = ldb;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
MlasSgemmSchedule(&WorkBlock, ThreadPool);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
const float* A,
|
||||
size_t lda,
|
||||
const void* PackedB,
|
||||
float beta,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the single precision matrix/matrix multiply
|
||||
operation (SGEMM).
|
||||
|
||||
Arguments:
|
||||
|
||||
TransA - Supplies the transpose operation for matrix A.
|
||||
|
||||
M - Supplies the number of rows of matrix A and matrix C.
|
||||
|
||||
N - Supplies the number of columns of matrix B and matrix C.
|
||||
|
||||
K - Supplies the number of columns of matrix A and the number of rows of
|
||||
matrix B.
|
||||
|
||||
alpha - Supplies the scalar alpha multiplier (see SGEMM definition).
|
||||
|
||||
A - Supplies the address of matrix A.
|
||||
|
||||
lda - Supplies the first dimension of matrix A.
|
||||
|
||||
PackedB - Supplies the address of packed matrix B.
|
||||
|
||||
beta - Supplies the scalar beta multiplier (see SGEMM definition).
|
||||
|
||||
C - Supplies the address of matrix C.
|
||||
|
||||
ldc - Supplies the first dimension of matrix C.
|
||||
|
||||
ThreadPool - Supplies the thread pool object to use, else nullptr if the
|
||||
base library threading support should be used.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
MLAS_SGEMM_WORK_BLOCK WorkBlock;
|
||||
|
||||
//
|
||||
// Capture the GEMM parameters to the work block.
|
||||
//
|
||||
|
||||
memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK));
|
||||
|
||||
WorkBlock.TransA = TransA;
|
||||
WorkBlock.M = M;
|
||||
WorkBlock.N = N;
|
||||
WorkBlock.K = K;
|
||||
WorkBlock.A = A;
|
||||
WorkBlock.lda = lda;
|
||||
WorkBlock.B = PackedB;
|
||||
WorkBlock.C = C;
|
||||
WorkBlock.ldc = ldc;
|
||||
WorkBlock.alpha = alpha;
|
||||
WorkBlock.beta = beta;
|
||||
WorkBlock.BIsPacked = true;
|
||||
|
||||
//
|
||||
// Schedule the operation across a set of worker threads.
|
||||
//
|
||||
|
||||
MlasSgemmSchedule(&WorkBlock, ThreadPool);
|
||||
MlasTrySimpleParallel(ThreadPool,
|
||||
ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize),
|
||||
[=](ptrdiff_t tid)
|
||||
{
|
||||
ptrdiff_t GemmIdx = tid / ThreadsPerGemm;
|
||||
ptrdiff_t ThreadIdx = tid % ThreadsPerGemm;
|
||||
MlasSgemmThreaded(ThreadCountM, ThreadCountN,
|
||||
TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx);
|
||||
});
|
||||
}
|
||||
|
||||
size_t
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ MlasTrySimpleParallel(
|
|||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
|
||||
|
||||
for (ptrdiff_t tid = 0; tid < Iterations; tid++) {
|
||||
Work(tid);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,38 +159,27 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
|
|||
const auto* b_data = b ? b->Data<float>() : nullptr;
|
||||
auto* y_data = y->MutableData<float>();
|
||||
|
||||
// TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well
|
||||
size_t max_len = helper.OutputOffsets().size();
|
||||
const size_t max_len = helper.OutputOffsets().size();
|
||||
const size_t M = static_cast<size_t>(helper.M());
|
||||
const size_t N = static_cast<size_t>(helper.N());
|
||||
const size_t K = static_cast<size_t>(helper.K());
|
||||
const size_t lda = static_cast<int>(trans_a ? M : K);
|
||||
const size_t ldb = static_cast<int>(trans_b ? K : N);
|
||||
|
||||
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
|
||||
for (size_t i = 0; i < max_len; i++) {
|
||||
if (packed_b_) {
|
||||
MlasGemm(
|
||||
trans_a ? CblasTrans : CblasNoTrans,
|
||||
static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
alpha_attr_,
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(trans_a ? helper.M() : helper.K()),
|
||||
packed_b_.get(),
|
||||
0.0f,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool);
|
||||
continue;
|
||||
}
|
||||
math::Gemm<float, concurrency::ThreadPool>(
|
||||
trans_a ? CblasTrans : CblasNoTrans,
|
||||
trans_b ? CblasTrans : CblasNoTrans,
|
||||
helper.M(),
|
||||
helper.N(),
|
||||
helper.K(),
|
||||
alpha_attr_,
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
b_data + helper.RightOffsets()[i],
|
||||
0.0f,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
thread_pool);
|
||||
data[i].BIsPacked = bool(packed_b_);
|
||||
data[i].A = a_data + helper.LeftOffsets()[i];
|
||||
data[i].lda = lda;
|
||||
data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i];
|
||||
data[i].ldb = ldb;
|
||||
data[i].C = y_data + helper.OutputOffsets()[i];
|
||||
data[i].ldc = N;
|
||||
data[i].alpha = alpha_attr_;
|
||||
data[i].beta = 0.0f;
|
||||
}
|
||||
MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, data.data(), max_len, thread_pool);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool());
|
||||
|
||||
//TODO!! consider making this a post processor, so that we can parallize this loop
|
||||
MlasRequantizeOutput(gemm_output,
|
||||
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
|
||||
nullptr,
|
||||
|
|
|
|||
|
|
@ -17,18 +17,18 @@ void QGEMM(benchmark::State& state, bool pack_b) {
|
|||
const uint8_t a_zero_point = 29;
|
||||
const uint8_t b_zero_point = 179;
|
||||
|
||||
const int64_t M = state.range(0);
|
||||
const int64_t N = state.range(1);
|
||||
const int64_t K = state.range(2);
|
||||
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
|
||||
if (state.range(3) <= 0) throw std::invalid_argument("Batch must greater than 0!");
|
||||
if (state.range(4) <= 0) throw std::invalid_argument("Threads must greater than 0!");
|
||||
|
||||
const int64_t batch = state.range(3);
|
||||
const int64_t threads = state.range(4);
|
||||
const size_t M = static_cast<size_t>(state.range(0));
|
||||
const size_t N = static_cast<size_t>(state.range(1));
|
||||
const size_t K = static_cast<size_t>(state.range(2));
|
||||
|
||||
if (M <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (N <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
if (K <= 0) throw std::invalid_argument("K must greater than 0!");
|
||||
if (batch <= 0) throw std::invalid_argument("Batch must greater than 0!");
|
||||
if (threads <= 0) throw std::invalid_argument("Threads must greater than 0!");
|
||||
const size_t batch = static_cast<size_t>(state.range(3));
|
||||
const size_t threads = static_cast<size_t>(state.range(4));
|
||||
|
||||
OrtThreadPoolParams tpo;
|
||||
tpo.thread_pool_size = int(threads);
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@
|
|||
static const std::vector<std::string> sgemm_bench_arg_names = {"M", "N", "K"};
|
||||
|
||||
void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) {
|
||||
const int64_t M = state.range(0);
|
||||
const int64_t N = state.range(1);
|
||||
const int64_t K = state.range(2);
|
||||
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
|
||||
const size_t M = static_cast<size_t>(state.range(0));
|
||||
const size_t N = static_cast<size_t>(state.range(1));
|
||||
const size_t K = static_cast<size_t>(state.range(2));
|
||||
|
||||
if (M <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (N <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
if (K <= 0) throw std::invalid_argument("K must greater than 0!");
|
||||
|
||||
auto A = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
|
||||
auto B = RandomVectorUniform(static_cast<size_t>(N * K), -1.0f, 1.0f);
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
template <> MlasFgemmTest<float, false, false>* MlasTestFixture<MlasFgemmTest<float, false, false>>::mlas_tester(nullptr);
|
||||
template <> MlasFgemmTest<float, false, true>* MlasTestFixture<MlasFgemmTest<float, false, true>>::mlas_tester(nullptr);
|
||||
template <> MlasFgemmTest<float, true, false>* MlasTestFixture<MlasFgemmTest<float, true, false>>::mlas_tester(nullptr);
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ const char* GetGemmTestSuitePrefix<double>() {
|
|||
template <typename T, bool Packed>
|
||||
class FgemmPackedContext;
|
||||
|
||||
template <typename T>
|
||||
class FgemmPackedContext<T, false> {
|
||||
template <>
|
||||
class FgemmPackedContext<float, false> {
|
||||
public:
|
||||
void
|
||||
TestGemm(
|
||||
|
|
@ -31,21 +31,69 @@ class FgemmPackedContext<T, false> {
|
|||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
const T* A,
|
||||
size_t BatchSize,
|
||||
const float alpha,
|
||||
const float* A,
|
||||
size_t lda,
|
||||
const T* B,
|
||||
const float* B,
|
||||
size_t ldb,
|
||||
float beta,
|
||||
T* C,
|
||||
const float beta,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* threadpool) {
|
||||
MlasGemm(TransA, TransB, M, N, K, T(alpha), A, lda, B, ldb, T(beta), C, ldc, threadpool);
|
||||
std::vector<MLAS_SGEMM_DATA_PARAMS> data(BatchSize);
|
||||
for (size_t i = 0; i < BatchSize; i++) {
|
||||
data[i].A = A + M * K * i;
|
||||
data[i].lda = lda;
|
||||
data[i].B = B + K * N * i;
|
||||
data[i].ldb = ldb;
|
||||
data[i].C = C + M * N * i;
|
||||
data[i].ldc = ldc;
|
||||
data[i].alpha = alpha;
|
||||
data[i].beta = beta;
|
||||
}
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FgemmPackedContext<T, true> {
|
||||
#ifdef MLAS_TARGET_AMD64
|
||||
template <>
|
||||
class FgemmPackedContext<double, false> {
|
||||
public:
|
||||
void TestGemm(
|
||||
CBLAS_TRANSPOSE TransA,
|
||||
CBLAS_TRANSPOSE TransB,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BatchSize,
|
||||
double alpha,
|
||||
const double* A,
|
||||
size_t lda,
|
||||
const double* B,
|
||||
size_t ldb,
|
||||
double beta,
|
||||
double* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* threadpool) {
|
||||
std::vector<MLAS_DGEMM_DATA_PARAMS> data(BatchSize);
|
||||
for (size_t i = 0; i < BatchSize; i++) {
|
||||
data[i].A = A + M * K * i;
|
||||
data[i].lda = lda;
|
||||
data[i].B = B + K * N * i;
|
||||
data[i].ldb = ldb;
|
||||
data[i].C = C + M * N * i;
|
||||
data[i].ldc = ldc;
|
||||
data[i].alpha = alpha;
|
||||
data[i].beta = beta;
|
||||
}
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
class FgemmPackedContext<float, true> {
|
||||
public:
|
||||
void
|
||||
TestGemm(
|
||||
|
|
@ -54,19 +102,34 @@ class FgemmPackedContext<T, true> {
|
|||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
const T* A,
|
||||
size_t BatchSize,
|
||||
const float alpha,
|
||||
const float* A,
|
||||
size_t lda,
|
||||
const T* B,
|
||||
const float* B,
|
||||
size_t ldb,
|
||||
float beta,
|
||||
T* C,
|
||||
const float beta,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
MLAS_THREADPOOL* threadpool) {
|
||||
size_t PackedBSize = MlasGemmPackBSize(N, K);
|
||||
void* PackedB = BufferBPacked.GetBuffer(PackedBSize, true);
|
||||
MlasGemmPackB(TransB, N, K, B, ldb, PackedB);
|
||||
MlasGemm(TransA, M, N, K, T(alpha), A, lda, PackedB, T(beta), C, ldc, threadpool);
|
||||
void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true);
|
||||
std::vector<MLAS_SGEMM_DATA_PARAMS> data(BatchSize);
|
||||
for (size_t i = 0; i < BatchSize; i++) {
|
||||
MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i);
|
||||
data[i].BIsPacked = true;
|
||||
data[i].A = A + M * K * i;
|
||||
data[i].lda = lda;
|
||||
data[i].B = (float*)((uint8_t*)PackedB + PackedBSize * i);
|
||||
data[i].ldb = ldb;
|
||||
data[i].C = C + M * N * i;
|
||||
data[i].ldc = ldc;
|
||||
data[i].alpha = alpha;
|
||||
data[i].beta = beta;
|
||||
}
|
||||
MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool);
|
||||
|
||||
MlasGemm(TransA, M, N, K, alpha, A, lda, PackedB, beta, C, ldc, threadpool);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -89,14 +152,14 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
|
||||
MlasFgemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) { }
|
||||
|
||||
void Test(size_t M, size_t N, size_t K, float alpha, float beta) {
|
||||
Test(false, false, M, N, K, alpha, beta);
|
||||
Test(false, true, M, N, K, alpha, beta);
|
||||
Test(true, false, M, N, K, alpha, beta);
|
||||
Test(true, true, M, N, K, alpha, beta);
|
||||
void Test(size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) {
|
||||
Test(false, false, M, N, K, BatchSize, alpha, beta);
|
||||
Test(false, true, M, N, K, BatchSize, alpha, beta);
|
||||
Test(true, false, M, N, K, BatchSize, alpha, beta);
|
||||
Test(true, true, M, N, K, BatchSize, alpha, beta);
|
||||
}
|
||||
|
||||
void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) {
|
||||
void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) {
|
||||
//
|
||||
// Skip the test if the B buffer cannot be packed.
|
||||
//
|
||||
|
|
@ -105,14 +168,14 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
return;
|
||||
}
|
||||
|
||||
const T* A = BufferA.GetBuffer(K * M);
|
||||
const T* B = BufferB.GetBuffer(N * K);
|
||||
T* C = BufferC.GetBuffer(N * M);
|
||||
T* CReference = BufferCReference.GetBuffer(N * M);
|
||||
const T* A = BufferA.GetBuffer(K * M * BatchSize);
|
||||
const T* B = BufferB.GetBuffer(N * K * BatchSize);
|
||||
T* C = BufferC.GetBuffer(N * M * BatchSize);
|
||||
T* CReference = BufferCReference.GetBuffer(N * M * BatchSize);
|
||||
|
||||
Test(trans_a ? CblasTrans : CblasNoTrans,
|
||||
trans_b ? CblasTrans : CblasNoTrans,
|
||||
M, N, K, alpha, A, trans_a ? M : K, B, trans_b ? K : N,
|
||||
M, N, K, BatchSize, alpha, A, trans_a ? M : K, B, trans_b ? K : N,
|
||||
beta, C, CReference, N);
|
||||
}
|
||||
|
||||
|
|
@ -121,33 +184,36 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
size_t BatchSize,
|
||||
T alpha,
|
||||
const T* A,
|
||||
size_t lda,
|
||||
const T* B,
|
||||
size_t ldb,
|
||||
float beta,
|
||||
T beta,
|
||||
T* C,
|
||||
T* CReference,
|
||||
size_t ldc) {
|
||||
std::fill_n(C, M * N, -0.5f);
|
||||
std::fill_n(CReference, M * N, -0.5f);
|
||||
std::fill_n(C, M * N * BatchSize, -0.5f);
|
||||
std::fill_n(CReference, M * N * BatchSize, -0.5f);
|
||||
|
||||
PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_);
|
||||
ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc);
|
||||
PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_);
|
||||
ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc);
|
||||
|
||||
for (size_t m = 0, f = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
// Sensitive to comparing positive/negative zero.
|
||||
ASSERT_EQ(C[f], CReference[f])
|
||||
<< " Diff @[" << m << ", " << n << "] f=" << f << ", "
|
||||
<< (Packed ? "Packed" : "NoPack") << "."
|
||||
<< (Threaded ? "SingleThread" : "Threaded") << "/"
|
||||
<< (TransA == CblasTrans ? "TransA" : "A") << "/"
|
||||
<< (TransB == CblasTrans ? "TransB" : "B") << "/"
|
||||
<< "M" << M << "xN" << N << "xK" << K << "/"
|
||||
<< "Alpha" << alpha << "/"
|
||||
<< "Beta" << beta;
|
||||
for (size_t batch = 0, f = 0; batch < BatchSize; batch++) {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++, f++) {
|
||||
// Sensitive to comparing positive/negative zero.
|
||||
ASSERT_EQ(C[f], CReference[f])
|
||||
<< " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", "
|
||||
<< (Packed ? "Packed" : "NoPack") << "."
|
||||
<< (Threaded ? "SingleThread" : "Threaded") << "/"
|
||||
<< (TransA == CblasTrans ? "TransA" : "A") << "/"
|
||||
<< (TransB == CblasTrans ? "TransB" : "B") << "/"
|
||||
<< "M" << M << "xN" << N << "xK" << K << "/"
|
||||
<< "Alpha" << alpha << "/"
|
||||
<< "Beta" << beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -157,111 +223,120 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
float alpha,
|
||||
size_t BatchSize,
|
||||
T alpha,
|
||||
const T* A,
|
||||
size_t lda,
|
||||
const T* B,
|
||||
size_t ldb,
|
||||
float beta,
|
||||
T beta,
|
||||
T* C,
|
||||
size_t ldc) {
|
||||
if (TransA == CblasNoTrans) {
|
||||
if (TransB == CblasNoTrans) {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + (m * lda);
|
||||
const T* b = B + n;
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
for (size_t batch = 0; batch < BatchSize; batch++) {
|
||||
if (TransA == CblasNoTrans) {
|
||||
if (TransB == CblasNoTrans) {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + (m * lda);
|
||||
const T* b = B + n;
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += ldb;
|
||||
a += 1;
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += ldb;
|
||||
a += 1;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
} else {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + (m * lda);
|
||||
const T* b = B + (n * ldb);
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += 1;
|
||||
a += 1;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + (m * lda);
|
||||
const T* b = B + (n * ldb);
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
if (TransB == CblasNoTrans) {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + m;
|
||||
const T* b = B + n;
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += 1;
|
||||
a += 1;
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += ldb;
|
||||
a += lda;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if (TransB == CblasNoTrans) {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + m;
|
||||
const T* b = B + n;
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += ldb;
|
||||
a += lda;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + m;
|
||||
const T* b = B + (n * ldb);
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += 1;
|
||||
a += lda;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
} else {
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
const T* a = A + m;
|
||||
const T* b = B + (n * ldb);
|
||||
T* c = C + (m * ldc) + n;
|
||||
T sum = 0.0f;
|
||||
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
sum += (*b * *a);
|
||||
b += 1;
|
||||
a += lda;
|
||||
}
|
||||
|
||||
*c = (*c * beta) + (sum * alpha);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
A += M * K;
|
||||
B += K * N;
|
||||
C += M * N;
|
||||
}
|
||||
}
|
||||
|
||||
void ExecuteLong() override {
|
||||
static const float multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f};
|
||||
static const T multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f};
|
||||
|
||||
for (size_t N = 1; N < 128; N++) {
|
||||
for (size_t K = 1; K < 128; K++) {
|
||||
for (size_t a = 0; a < _countof(multipliers); a++) {
|
||||
for (size_t b = 0; b < _countof(multipliers); b++) {
|
||||
Test(1, N, K, multipliers[a], multipliers[b]);
|
||||
Test(N, 1, K, multipliers[a], multipliers[b]);
|
||||
Test(1, N, K, 1, multipliers[a], multipliers[b]);
|
||||
Test(N, 1, K, 1, multipliers[a], multipliers[b]);
|
||||
if (!Packed) {
|
||||
Test(1, N, K, 3, multipliers[a], multipliers[b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t a = 0; a < _countof(multipliers); a++) {
|
||||
float alpha = multipliers[a];
|
||||
T alpha = multipliers[a];
|
||||
|
||||
for (size_t b = 0; b < _countof(multipliers); b++) {
|
||||
float beta = multipliers[b];
|
||||
T beta = multipliers[b];
|
||||
|
||||
for (size_t M = 16; M < 160; M += 32) {
|
||||
for (size_t N = 16; N < 160; N += 32) {
|
||||
|
|
@ -269,21 +344,25 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
for (size_t k = 0; k < _countof(ks); k++) {
|
||||
size_t K = ks[k];
|
||||
|
||||
Test(M, N, K, alpha, beta);
|
||||
Test(M + 1, N, K, alpha, beta);
|
||||
Test(M, N + 1, K, alpha, beta);
|
||||
Test(M + 1, N + 1, K, alpha, beta);
|
||||
Test(M + 3, N + 2, K, alpha, beta);
|
||||
Test(M + 4, N, K, alpha, beta);
|
||||
Test(M, N + 4, K, alpha, beta);
|
||||
Test(M + 4, N + 4, K, alpha, beta);
|
||||
Test(M + 3, N + 7, K, alpha, beta);
|
||||
Test(M + 8, N, K, alpha, beta);
|
||||
Test(M, N + 8, K, alpha, beta);
|
||||
Test(M + 12, N + 12, K, alpha, beta);
|
||||
Test(M + 13, N, K, alpha, beta);
|
||||
Test(M, N + 15, K, alpha, beta);
|
||||
Test(M + 15, N + 15, K, alpha, beta);
|
||||
Test(M, N, K, 1, alpha, beta);
|
||||
Test(M + 1, N, K, 1, alpha, beta);
|
||||
Test(M, N + 1, K, 1, alpha, beta);
|
||||
Test(M + 1, N + 1, K, 1, alpha, beta);
|
||||
Test(M + 3, N + 2, K, 1, alpha, beta);
|
||||
Test(M + 4, N, K, 1, alpha, beta);
|
||||
Test(M, N + 4, K, 1, alpha, beta);
|
||||
Test(M + 4, N + 4, K, 1, alpha, beta);
|
||||
Test(M + 3, N + 7, K, 1, alpha, beta);
|
||||
Test(M + 8, N, K, 1, alpha, beta);
|
||||
Test(M, N + 8, K, 1, alpha, beta);
|
||||
Test(M + 12, N + 12, K, 1, alpha, beta);
|
||||
Test(M + 13, N, K, 1, alpha, beta);
|
||||
Test(M, N + 15, K, 1, alpha, beta);
|
||||
Test(M + 15, N + 15, K, 1, alpha, beta);
|
||||
if (!Packed) {
|
||||
Test(M + 3, N + 1, K, 7, multipliers[a], multipliers[b]);
|
||||
Test(M + 13, N + 2, K, 9, multipliers[a], multipliers[b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M);
|
||||
|
|
@ -294,7 +373,7 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
for (size_t M = 0; M < 160; M++) {
|
||||
for (size_t N = 0; N < 160; N++) {
|
||||
for (size_t K = 0; K < 160; K++) {
|
||||
Test(M, N, K, 1.0f, 0.0f);
|
||||
Test(M, N, K, 1, 1.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
printf("M %zd\n", M);
|
||||
|
|
@ -303,10 +382,10 @@ class MlasFgemmTest : public MlasTestBase {
|
|||
for (size_t M = 160; M < 320; M += 24) {
|
||||
for (size_t N = 112; N < 320; N += 24) {
|
||||
for (size_t K = 0; K < 16; K++) {
|
||||
Test(M, N, K, 1.0f, 0.0f);
|
||||
Test(M, N, K, 1, 1.0f, 0.0f);
|
||||
}
|
||||
for (size_t K = 16; K < 160; K += 32) {
|
||||
Test(M, N, K, 1.0f, 0.0f);
|
||||
Test(M, N, K, 1, 1.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
printf("M %zd\n", M);
|
||||
|
|
|
|||
|
|
@ -14,20 +14,20 @@
|
|||
template <typename T, bool Packed, bool Threaded>
|
||||
class FgemmShortExecuteTest : public MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>> {
|
||||
public:
|
||||
explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta)
|
||||
: trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), alpha_(alpha), beta_(beta) {
|
||||
explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta)
|
||||
: trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), Batch_(BatchSize), alpha_(alpha), beta_(beta) {
|
||||
}
|
||||
|
||||
void TestBody() override {
|
||||
MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>>::mlas_tester->Test(
|
||||
trans_a_, trans_b_, M_, N_, K_, alpha_, beta_);
|
||||
trans_a_, trans_b_, M_, N_, K_, Batch_, alpha_, beta_);
|
||||
}
|
||||
|
||||
static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) {
|
||||
static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) {
|
||||
std::stringstream ss;
|
||||
ss << (trans_a ? "TransA" : "A") << "/"
|
||||
<< (trans_b ? "TransB" : "B") << "/"
|
||||
<< "M" << M << "xN" << N << "xK" << K << "/"
|
||||
<< "BatchSize" << BatchSize << "/M" << M << "xN" << N << "xK" << K << "/"
|
||||
<< "Alpha" << alpha << "/"
|
||||
<< "Beta" << beta;
|
||||
auto test_name = ss.str();
|
||||
|
|
@ -42,37 +42,39 @@ class FgemmShortExecuteTest : public MlasTestFixture<MlasFgemmTest<T, Packed, Th
|
|||
// Important to use the fixture type as the return type here.
|
||||
[=]() -> MlasTestFixture<MlasFgemmTest<T, Packed, Threaded>>* {
|
||||
return new FgemmShortExecuteTest<T, Packed, Threaded>(
|
||||
trans_a, trans_b, M, N, K, alpha, beta);
|
||||
trans_a, trans_b, M, N, K, BatchSize, alpha, beta);
|
||||
});
|
||||
return 1;
|
||||
}
|
||||
|
||||
static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, float alpha, float beta) {
|
||||
return RegisterSingleTest(false, false, M, N, K, alpha, beta) +
|
||||
RegisterSingleTest(false, true, M, N, K, alpha, beta) +
|
||||
RegisterSingleTest(true, false, M, N, K, alpha, beta) +
|
||||
RegisterSingleTest(true, true, M, N, K, alpha, beta);
|
||||
static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) {
|
||||
return RegisterSingleTest(false, false, M, N, K, BatchSize, alpha, beta) +
|
||||
RegisterSingleTest(false, true, M, N, K, BatchSize, alpha, beta) +
|
||||
RegisterSingleTest(true, false, M, N, K, BatchSize, alpha, beta) +
|
||||
RegisterSingleTest(true, true, M, N, K, BatchSize, alpha, beta);
|
||||
}
|
||||
|
||||
static size_t RegisterShortExecuteTests() {
|
||||
size_t test_registered = 0;
|
||||
for (size_t b = 0; b < 16; b++) {
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 3, 1.0f, 0.0f);
|
||||
}
|
||||
for (size_t b = 16; b <= 256; b <<= 1) {
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
|
||||
}
|
||||
for (size_t b = 256; b < 320; b += 32) {
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1, 1.0f, 0.0f);
|
||||
test_registered += RegisterTestTransposeABProduct(25, 81, 79, 7, 1.0f, 0.0f);
|
||||
return test_registered;
|
||||
}
|
||||
|
||||
private:
|
||||
bool trans_a_, trans_b_;
|
||||
size_t M_, N_, K_;
|
||||
float alpha_, beta_;
|
||||
const size_t M_, N_, K_, Batch_;
|
||||
const T alpha_, beta_;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue