From f4f2cc1a00de102f289b1d3934948b3d6baba7a2 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Fri, 23 Apr 2021 17:34:22 -0700 Subject: [PATCH] Add batch interface to floating point GEMM (#7323) Currently in high dimension matmul, we call multiple GEMM sequentially. In this change we execute these GEMMs in parallel, removing barriers between two adjacent GEMM operations. Performance tested with Bert and T5 model. Bert model shows no noticeable perf differences, as the heavy lifting is done by the attention operator, which is not changed in this PR. In T5 model, we see no regression on low parallel threads (x4), and performance improvement is more pronounced in high number of threads (8-16). T5 shows 10% speedup with 16 threads. With profiling, we can see the most expensive MatMul operators in T5 achieves around 20% speedup with 16 threads. Co-authored-by: Chen Fu --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 1 + onnxruntime/core/mlas/inc/mlas.h | 250 ++++++++++++- onnxruntime/core/mlas/lib/dgemm.cpp | 213 +++-------- onnxruntime/core/mlas/lib/sgemm.cpp | 313 +++------------- onnxruntime/core/mlas/lib/threading.cpp | 2 +- onnxruntime/core/providers/cpu/math/matmul.cc | 49 +-- .../cpu/math/quantize_linear_matmul.cc | 1 + onnxruntime/test/mlas/bench/bench_qgemm.cpp | 20 +- onnxruntime/test/mlas/bench/bench_sgemm.cpp | 12 +- onnxruntime/test/mlas/unittest/test_fgemm.cpp | 1 - onnxruntime/test/mlas/unittest/test_fgemm.h | 347 +++++++++++------- .../test/mlas/unittest/test_fgemm_fixture.h | 38 +- 12 files changed, 624 insertions(+), 623 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 852247c43b..68e5b6815e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -309,6 +309,7 @@ Status Attention::Compute(OpKernelContext* context) const { T* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); + // TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process? // broadcast 3NH -> (3.B.N.S.H) const T* broadcast_data_src = bias_data + weights_offset; T* broadcast_data_dest = QKV[qkv_index] + qkv_offset; diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cad2bb8546..c417bdb17d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -147,10 +147,102 @@ MlasActivation( // // Matrix/matrix multiply routines. +// C := alpha * op(A) * op(B) + beta * C +// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X)) // +/** + * @brief Supply matrices data information to single precision gemm functions + */ +struct MLAS_SGEMM_DATA_PARAMS { + const float* A = nullptr; /**< Supplies the address of matrix A */ + size_t lda = 0; /**< Supplies the first dimension of matrix A. */ + const float* B = nullptr; /**< Supplies the address of matrix B */ + size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ + float* C = nullptr; /**< Supplies the address of matrix C */ + size_t ldc = 0; /**< Supplies the first dimension of matrix C. */ + float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ + float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ + bool BIsPacked = false; /**< Whether B is pre-packed */ +}; + +/** + * @brief Batched single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ void MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief Single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data Supplies the matrices data parameters + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS& Data, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** + * @brief Single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param beta Supplies the scalar beta multiplier (see SGEMM definition) + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void MlasGemm( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, @@ -166,10 +258,41 @@ MlasGemm( float* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_SGEMM_DATA_PARAMS Data; + Data.alpha = alpha; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.beta = beta; + Data.C = C; + Data.ldc = ldc; + MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool); +} + +/** + * @brief the single precision matrix/matrix multiply operation (SGEMM) with pre-packed B + * + * @param TransA - Supplies the transpose operation for matrix A. + * @param M - Supplies the number of rows of matrix A and matrix C. + * @param N - Supplies the number of columns of matrix B and matrix C. + * @param K - Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha - Supplies the scalar alpha multiplier (see SGEMM definition). + * @param A - Supplies the address of matrix A. + * @param lda - Supplies the first dimension of matrix A. + * @param PackedB - Supplies the address of packed matrix B. + * @param beta - Supplies the scalar beta multiplier (see SGEMM definition). + * @param C - Supplies the address of matrix C. + * @param ldc - Supplies the first dimension of matrix C. + * @param ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline void -MLASCALL MlasGemm( CBLAS_TRANSPOSE TransA, size_t M, @@ -183,10 +306,117 @@ MlasGemm( float* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_SGEMM_DATA_PARAMS DataParams; + DataParams.A = A; + DataParams.lda = lda; + DataParams.B = static_cast(PackedB); + DataParams.ldb = 0; + DataParams.C = C; + DataParams.ldc = ldc; + DataParams.alpha = alpha; + DataParams.beta = beta; + DataParams.BIsPacked = true; + MlasGemmBatch(TransA, + CblasTrans, // deos not matter when B is packed + M, N, K, &DataParams, 1, ThreadPool); +} + + +/** + * @brief Supply matrices data information to double precision gemm functions + */ +struct MLAS_DGEMM_DATA_PARAMS { + const double* A = nullptr; /**< Supplies the address of matrix A */ + size_t lda = 0; /**< Supplies the first dimension of matrix A. */ + const double* B = nullptr; /**< Supplies the address of matrix B */ + size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ + double* C = nullptr; /**< Supplies the address of matrix C */ + size_t ldc = 0; /**< Supplies the first dimension of matrix C. */ + double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ + double beta = 0.0; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ +}; + +/** + * @brief Batched double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ void MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + + +/** + * @brief Double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data Supplies the matrices data parameters + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS& Data, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** + * @brief Double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param beta Supplies the scalar beta multiplier (see SGEMM definition) + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void MlasGemm( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, @@ -202,7 +432,19 @@ MlasGemm( double* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_DGEMM_DATA_PARAMS Data; + Data.alpha = alpha; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.beta = beta; + Data.C = C; + Data.ldc = ldc; + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} enum class MLAS_QUANTIZATION_GRANULARITY { PerMatrix, diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp index 666856a6b6..d809e07d2a 100644 --- a/onnxruntime/core/mlas/lib/dgemm.cpp +++ b/onnxruntime/core/mlas/lib/dgemm.cpp @@ -26,29 +26,6 @@ Abstract: #define MLAS_DGEMM_TRANSA_ROWS 12 -// -// Define the parameters to execute segments of a DGEMM operation on worker -// threads. -// - -struct MLAS_DGEMM_WORK_BLOCK { - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - CBLAS_TRANSPOSE TransA; - CBLAS_TRANSPOSE TransB; - size_t M; - size_t N; - size_t K; - const double* A; - size_t lda; - const double* B; - size_t ldb; - double* C; - size_t ldc; - double alpha; - double beta; -}; - #ifdef MLAS_TARGET_AMD64 void @@ -750,8 +727,15 @@ Return Value: void MlasDgemmThreaded( - void* Context, - ptrdiff_t ThreadId + const ptrdiff_t ThreadCountM, + const ptrdiff_t ThreadCountN, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + const ptrdiff_t ThreadId ) /*++ @@ -772,10 +756,6 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context; - - const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM; - const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN; const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; @@ -784,7 +764,6 @@ Return Value: // Partition the operation along the M dimension. // - size_t M = WorkBlock->M; size_t RangeStartM; size_t RangeCountM; @@ -794,7 +773,6 @@ Return Value: // Partition the operation along the N dimension. // - size_t N = WorkBlock->N; size_t RangeStartN; size_t RangeCountN; @@ -813,50 +791,32 @@ Return Value: // Dispatch the partitioned operation. // - CBLAS_TRANSPOSE TransA = WorkBlock->TransA; - CBLAS_TRANSPOSE TransB = WorkBlock->TransB; + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; - const size_t lda = WorkBlock->lda; - const size_t ldb = WorkBlock->ldb; - const size_t ldc = WorkBlock->ldc; + const double* A = Data->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); + const double* B = Data->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); + double* C = Data->C + RangeStartM * ldc + RangeStartN; - const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; - - MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, - WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); + MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, + Data->alpha, A, lda, B, ldb, Data->beta, C, ldc); } + void -MlasDgemmSchedule( - MLAS_DGEMM_WORK_BLOCK* WorkBlock, +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) -/*++ - -Routine Description: - - This routine schedules the double precision matrix/matrix multiply - operation (DGEMM) across one or more threads. - -Arguments: - - WorkBlock - Supplies the structure containing the GEMM parameters. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ { - const size_t M = WorkBlock->M; - const size_t N = WorkBlock->N; - const size_t K = WorkBlock->K; - // // Compute the number of target threads given the complexity of the DGEMM // operation. Small requests should run using the single threaded path. @@ -885,121 +845,40 @@ Return Value: // works okay for operations involving skinny matrices. // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + if (N > M) { const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - if (size_t(TargetThreadCount) > BlockedN) { - TargetThreadCount = ptrdiff_t(BlockedN); + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); } - WorkBlock->ThreadCountM = 1; - WorkBlock->ThreadCountN = TargetThreadCount; + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; } else { - if (size_t(TargetThreadCount) > M) { - TargetThreadCount = ptrdiff_t(M); + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); } - WorkBlock->ThreadCountM = TargetThreadCount; - WorkBlock->ThreadCountN = 1; + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; } - MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool); -} + const ptrdiff_t TotalThreads = ThreadsPerGemm * static_cast(BatchSize); + MlasTrySimpleParallel(ThreadPool, TotalThreads, [=](ptrdiff_t tid) { + const ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + const ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + MlasDgemmThreaded(ThreadCountM, ThreadCountN, TransA, TransB, + M, N, K, &(Data[GemmIdx]), ThreadIdx); + }); -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the double precision matrix/matrix multiply - operation (DGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see DGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see DGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_DGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.TransB = TransB; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - - // - // Schedule the operation across a set of worker threads. - // - - MlasDgemmSchedule(&WorkBlock, ThreadPool); } #endif diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 34ef3f9fa1..8934a8b520 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -31,25 +31,6 @@ Abstract: // threads. // -struct MLAS_SGEMM_WORK_BLOCK { - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - CBLAS_TRANSPOSE TransA; - CBLAS_TRANSPOSE TransB; - size_t M; - size_t N; - size_t K; - const float* A; - size_t lda; - const void* B; - size_t ldb; - float* C; - size_t ldc; - float alpha; - float beta; - bool BIsPacked; -}; - void MlasSgemmMultiplyBeta( float* C, @@ -1476,7 +1457,15 @@ Return Value: void MlasSgemmThreaded( - void* Context, + const ptrdiff_t ThreadCountM, + const ptrdiff_t ThreadCountN, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + + const MLAS_SGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId ) /*++ @@ -1488,7 +1477,17 @@ Routine Description: Arguments: - Context - Supplies the pointer to the context for the threaded operation. + ThreadCountM - Supplies the total thread partition on the M dimension. + + ThreadCountN - Supplies the total thread partition on the N dimension. + + TransA - Supplies the transpose operation on A matrix + + TransB - Supplies the transpose operation on B matrix + + M, N, K - Supplies the shape of the multiplication + + DataParams - Supplies the data position and layout of the matrices ThreadId - Supplies the current index of the threaded operation. @@ -1498,10 +1497,6 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context; - - const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM; - const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN; const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; @@ -1510,7 +1505,6 @@ Return Value: // Partition the operation along the M dimension. // - size_t M = WorkBlock->M; size_t RangeStartM; size_t RangeCountM; @@ -1520,7 +1514,6 @@ Return Value: // Partition the operation along the N dimension. // - size_t N = WorkBlock->N; size_t RangeStartN; size_t RangeCountN; @@ -1539,61 +1532,42 @@ Return Value: // Dispatch the partitioned operation. // - CBLAS_TRANSPOSE TransA = WorkBlock->TransA; + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; - const size_t lda = WorkBlock->lda; - const size_t ldc = WorkBlock->ldc; + const float* A = DataParams->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; - - if (WorkBlock->BIsPacked) { + if (DataParams->BIsPacked) { MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, - WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B, - BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc); + K, DataParams->alpha, A, lda, DataParams->B, + BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc); } else { - CBLAS_TRANSPOSE TransB = WorkBlock->TransB; + const size_t ldb = DataParams->ldb; - const size_t ldb = WorkBlock->ldb; + const float* B = (const float*)DataParams->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - - MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, - WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); + MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, + DataParams->alpha, A, lda, B, ldb, DataParams->beta, C, ldc); } } void -MlasSgemmSchedule( - MLAS_SGEMM_WORK_BLOCK* WorkBlock, +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) -/*++ - -Routine Description: - - This routine schedules the single precision matrix/matrix multiply - operation (SGEMM) across one or more threads. - -Arguments: - - WorkBlock - Supplies the structure containing the GEMM parameters. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ { - const size_t M = WorkBlock->M; - const size_t N = WorkBlock->N; - const size_t K = WorkBlock->K; // // Compute the number of target threads given the complexity of the SGEMM @@ -1623,206 +1597,41 @@ Return Value: // works okay for operations involving skinny matrices. // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + if (N > M) { const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - if (size_t(TargetThreadCount) > BlockedN) { - TargetThreadCount = ptrdiff_t(BlockedN); + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); } - WorkBlock->ThreadCountM = 1; - WorkBlock->ThreadCountN = TargetThreadCount; + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; } else { - if (size_t(TargetThreadCount) > M) { - TargetThreadCount = ptrdiff_t(M); + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); } - WorkBlock->ThreadCountM = TargetThreadCount; - WorkBlock->ThreadCountN = 1; + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; } - MlasExecuteThreaded(MlasSgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool); -} - -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_SGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.TransB = TransB; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - - // - // Schedule the operation across a set of worker threads. - // - - MlasSgemmSchedule(&WorkBlock, ThreadPool); -} - -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const void* PackedB, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - PackedB - Supplies the address of packed matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_SGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = PackedB; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - WorkBlock.BIsPacked = true; - - // - // Schedule the operation across a set of worker threads. - // - - MlasSgemmSchedule(&WorkBlock, ThreadPool); + MlasTrySimpleParallel(ThreadPool, + ThreadsPerGemm * static_cast(BatchSize), + [=](ptrdiff_t tid) + { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + MlasSgemmThreaded(ThreadCountM, ThreadCountN, + TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx); + }); } size_t diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index 85e6b2ef7e..8769abdf08 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -88,7 +88,7 @@ MlasTrySimpleParallel( #ifdef _OPENMP #pragma omp parallel for #endif - + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { Work(tid); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index adb4419d68..6165d922fa 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -159,38 +159,27 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const auto* b_data = b ? b->Data() : nullptr; auto* y_data = y->MutableData(); - // TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well - size_t max_len = helper.OutputOffsets().size(); + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = static_cast(trans_a ? M : K); + const size_t ldb = static_cast(trans_b ? K : N); + + std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { - if (packed_b_) { - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - alpha_attr_, - a_data + helper.LeftOffsets()[i], - static_cast(trans_a ? helper.M() : helper.K()), - packed_b_.get(), - 0.0f, - y_data + helper.OutputOffsets()[i], - static_cast(helper.N()), - thread_pool); - continue; - } - math::Gemm( - trans_a ? CblasTrans : CblasNoTrans, - trans_b ? CblasTrans : CblasNoTrans, - helper.M(), - helper.N(), - helper.K(), - alpha_attr_, - a_data + helper.LeftOffsets()[i], - b_data + helper.RightOffsets()[i], - 0.0f, - y_data + helper.OutputOffsets()[i], - thread_pool); + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc index 82b808d9e0..50a41cfa4c 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc @@ -108,6 +108,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const { MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool()); + //TODO!! consider making this a post processor, so that we can parallize this loop MlasRequantizeOutput(gemm_output, y->template MutableData() + helper.OutputOffsets()[i], nullptr, diff --git a/onnxruntime/test/mlas/bench/bench_qgemm.cpp b/onnxruntime/test/mlas/bench/bench_qgemm.cpp index 76f4caf655..86723a931b 100644 --- a/onnxruntime/test/mlas/bench/bench_qgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qgemm.cpp @@ -17,18 +17,18 @@ void QGEMM(benchmark::State& state, bool pack_b) { const uint8_t a_zero_point = 29; const uint8_t b_zero_point = 179; - const int64_t M = state.range(0); - const int64_t N = state.range(1); - const int64_t K = state.range(2); + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("Batch must greater than 0!"); + if (state.range(4) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - const int64_t batch = state.range(3); - const int64_t threads = state.range(4); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); - if (M <= 0) throw std::invalid_argument("M must greater than 0!"); - if (N <= 0) throw std::invalid_argument("N must greater than 0!"); - if (K <= 0) throw std::invalid_argument("K must greater than 0!"); - if (batch <= 0) throw std::invalid_argument("Batch must greater than 0!"); - if (threads <= 0) throw std::invalid_argument("Threads must greater than 0!"); + const size_t batch = static_cast(state.range(3)); + const size_t threads = static_cast(state.range(4)); OrtThreadPoolParams tpo; tpo.thread_pool_size = int(threads); diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index 4336311a90..93ae0f2eb2 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -10,13 +10,13 @@ static const std::vector sgemm_bench_arg_names = {"M", "N", "K"}; void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) { - const int64_t M = state.range(0); - const int64_t N = state.range(1); - const int64_t K = state.range(2); + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); - if (M <= 0) throw std::invalid_argument("M must greater than 0!"); - if (N <= 0) throw std::invalid_argument("N must greater than 0!"); - if (K <= 0) throw std::invalid_argument("K must greater than 0!"); auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); auto B = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.cpp b/onnxruntime/test/mlas/unittest/test_fgemm.cpp index 710083c6e0..89af5234ee 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_fgemm.cpp @@ -7,7 +7,6 @@ #include #include - template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h index 5a161ebafd..35d4622003 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm.h @@ -21,8 +21,8 @@ const char* GetGemmTestSuitePrefix() { template class FgemmPackedContext; -template -class FgemmPackedContext { +template <> +class FgemmPackedContext { public: void TestGemm( @@ -31,21 +31,69 @@ class FgemmPackedContext { size_t M, size_t N, size_t K, - float alpha, - const T* A, + size_t BatchSize, + const float alpha, + const float* A, size_t lda, - const T* B, + const float* B, size_t ldb, - float beta, - T* C, + const float beta, + float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - MlasGemm(TransA, TransB, M, N, K, T(alpha), A, lda, B, ldb, T(beta), C, ldc, threadpool); + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = B + K * N * i; + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); } }; -template -class FgemmPackedContext { +#ifdef MLAS_TARGET_AMD64 +template <> +class FgemmPackedContext { + public: + void TestGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + size_t BatchSize, + double alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + double beta, + double* C, + size_t ldc, + MLAS_THREADPOOL* threadpool) { + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = B + K * N * i; + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); + } +}; +#endif + +template <> +class FgemmPackedContext { public: void TestGemm( @@ -54,19 +102,34 @@ class FgemmPackedContext { size_t M, size_t N, size_t K, - float alpha, - const T* A, + size_t BatchSize, + const float alpha, + const float* A, size_t lda, - const T* B, + const float* B, size_t ldb, - float beta, - T* C, + const float beta, + float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { size_t PackedBSize = MlasGemmPackBSize(N, K); - void* PackedB = BufferBPacked.GetBuffer(PackedBSize, true); - MlasGemmPackB(TransB, N, K, B, ldb, PackedB); - MlasGemm(TransA, M, N, K, T(alpha), A, lda, PackedB, T(beta), C, ldc, threadpool); + void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + data[i].BIsPacked = true; + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = (float*)((uint8_t*)PackedB + PackedBSize * i); + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); + + MlasGemm(TransA, M, N, K, alpha, A, lda, PackedB, beta, C, ldc, threadpool); } private: @@ -89,14 +152,14 @@ class MlasFgemmTest : public MlasTestBase { MlasFgemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) { } - void Test(size_t M, size_t N, size_t K, float alpha, float beta) { - Test(false, false, M, N, K, alpha, beta); - Test(false, true, M, N, K, alpha, beta); - Test(true, false, M, N, K, alpha, beta); - Test(true, true, M, N, K, alpha, beta); + void Test(size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { + Test(false, false, M, N, K, BatchSize, alpha, beta); + Test(false, true, M, N, K, BatchSize, alpha, beta); + Test(true, false, M, N, K, BatchSize, alpha, beta); + Test(true, true, M, N, K, BatchSize, alpha, beta); } - void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) { + void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { // // Skip the test if the B buffer cannot be packed. // @@ -105,14 +168,14 @@ class MlasFgemmTest : public MlasTestBase { return; } - const T* A = BufferA.GetBuffer(K * M); - const T* B = BufferB.GetBuffer(N * K); - T* C = BufferC.GetBuffer(N * M); - T* CReference = BufferCReference.GetBuffer(N * M); + const T* A = BufferA.GetBuffer(K * M * BatchSize); + const T* B = BufferB.GetBuffer(N * K * BatchSize); + T* C = BufferC.GetBuffer(N * M * BatchSize); + T* CReference = BufferCReference.GetBuffer(N * M * BatchSize); Test(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, alpha, A, trans_a ? M : K, B, trans_b ? K : N, + M, N, K, BatchSize, alpha, A, trans_a ? M : K, B, trans_b ? K : N, beta, C, CReference, N); } @@ -121,33 +184,36 @@ class MlasFgemmTest : public MlasTestBase { size_t M, size_t N, size_t K, - float alpha, + size_t BatchSize, + T alpha, const T* A, size_t lda, const T* B, size_t ldb, - float beta, + T beta, T* C, T* CReference, size_t ldc) { - std::fill_n(C, M * N, -0.5f); - std::fill_n(CReference, M * N, -0.5f); + std::fill_n(C, M * N * BatchSize, -0.5f); + std::fill_n(CReference, M * N * BatchSize, -0.5f); - PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); - ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc); + PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); + ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc); - for (size_t m = 0, f = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - // Sensitive to comparing positive/negative zero. - ASSERT_EQ(C[f], CReference[f]) - << " Diff @[" << m << ", " << n << "] f=" << f << ", " - << (Packed ? "Packed" : "NoPack") << "." - << (Threaded ? "SingleThread" : "Threaded") << "/" - << (TransA == CblasTrans ? "TransA" : "A") << "/" - << (TransB == CblasTrans ? "TransB" : "B") << "/" - << "M" << M << "xN" << N << "xK" << K << "/" - << "Alpha" << alpha << "/" - << "Beta" << beta; + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + // Sensitive to comparing positive/negative zero. + ASSERT_EQ(C[f], CReference[f]) + << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " + << (Packed ? "Packed" : "NoPack") << "." + << (Threaded ? "SingleThread" : "Threaded") << "/" + << (TransA == CblasTrans ? "TransA" : "A") << "/" + << (TransB == CblasTrans ? "TransB" : "B") << "/" + << "M" << M << "xN" << N << "xK" << K << "/" + << "Alpha" << alpha << "/" + << "Beta" << beta; + } } } } @@ -157,111 +223,120 @@ class MlasFgemmTest : public MlasTestBase { size_t M, size_t N, size_t K, - float alpha, + size_t BatchSize, + T alpha, const T* A, size_t lda, const T* B, size_t ldb, - float beta, + T beta, T* C, size_t ldc) { - if (TransA == CblasNoTrans) { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; + for (size_t batch = 0; batch < BatchSize; batch++) { + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + (m * lda); + const T* b = B + n; + T* c = C + (m * ldc) + n; + T sum = 0.0f; - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += 1; + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += ldb; + a += 1; + } + + *c = (*c * beta) + (sum * alpha); } + } - *c = (*c * beta) + (sum * alpha); + } else { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + (m * lda); + const T* b = B + (n * ldb); + T* c = C + (m * ldc) + n; + T sum = 0.0f; + + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += 1; + a += 1; + } + + *c = (*c * beta) + (sum * alpha); + } } } } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; + if (TransB == CblasNoTrans) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + m; + const T* b = B + n; + T* c = C + (m * ldc) + n; + T sum = 0.0f; - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += 1; + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += ldb; + a += lda; + } + + *c = (*c * beta) + (sum * alpha); } + } - *c = (*c * beta) + (sum * alpha); - } - } - } - - } else { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); - } - } - - } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); + } else { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + m; + const T* b = B + (n * ldb); + T* c = C + (m * ldc) + n; + T sum = 0.0f; + + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += 1; + a += lda; + } + + *c = (*c * beta) + (sum * alpha); + } } } } + A += M * K; + B += K * N; + C += M * N; } } void ExecuteLong() override { - static const float multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f}; + static const T multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f}; for (size_t N = 1; N < 128; N++) { for (size_t K = 1; K < 128; K++) { for (size_t a = 0; a < _countof(multipliers); a++) { for (size_t b = 0; b < _countof(multipliers); b++) { - Test(1, N, K, multipliers[a], multipliers[b]); - Test(N, 1, K, multipliers[a], multipliers[b]); + Test(1, N, K, 1, multipliers[a], multipliers[b]); + Test(N, 1, K, 1, multipliers[a], multipliers[b]); + if (!Packed) { + Test(1, N, K, 3, multipliers[a], multipliers[b]); + } } } } } for (size_t a = 0; a < _countof(multipliers); a++) { - float alpha = multipliers[a]; + T alpha = multipliers[a]; for (size_t b = 0; b < _countof(multipliers); b++) { - float beta = multipliers[b]; + T beta = multipliers[b]; for (size_t M = 16; M < 160; M += 32) { for (size_t N = 16; N < 160; N += 32) { @@ -269,21 +344,25 @@ class MlasFgemmTest : public MlasTestBase { for (size_t k = 0; k < _countof(ks); k++) { size_t K = ks[k]; - Test(M, N, K, alpha, beta); - Test(M + 1, N, K, alpha, beta); - Test(M, N + 1, K, alpha, beta); - Test(M + 1, N + 1, K, alpha, beta); - Test(M + 3, N + 2, K, alpha, beta); - Test(M + 4, N, K, alpha, beta); - Test(M, N + 4, K, alpha, beta); - Test(M + 4, N + 4, K, alpha, beta); - Test(M + 3, N + 7, K, alpha, beta); - Test(M + 8, N, K, alpha, beta); - Test(M, N + 8, K, alpha, beta); - Test(M + 12, N + 12, K, alpha, beta); - Test(M + 13, N, K, alpha, beta); - Test(M, N + 15, K, alpha, beta); - Test(M + 15, N + 15, K, alpha, beta); + Test(M, N, K, 1, alpha, beta); + Test(M + 1, N, K, 1, alpha, beta); + Test(M, N + 1, K, 1, alpha, beta); + Test(M + 1, N + 1, K, 1, alpha, beta); + Test(M + 3, N + 2, K, 1, alpha, beta); + Test(M + 4, N, K, 1, alpha, beta); + Test(M, N + 4, K, 1, alpha, beta); + Test(M + 4, N + 4, K, 1, alpha, beta); + Test(M + 3, N + 7, K, 1, alpha, beta); + Test(M + 8, N, K, 1, alpha, beta); + Test(M, N + 8, K, 1, alpha, beta); + Test(M + 12, N + 12, K, 1, alpha, beta); + Test(M + 13, N, K, 1, alpha, beta); + Test(M, N + 15, K, 1, alpha, beta); + Test(M + 15, N + 15, K, 1, alpha, beta); + if (!Packed) { + Test(M + 3, N + 1, K, 7, multipliers[a], multipliers[b]); + Test(M + 13, N + 2, K, 9, multipliers[a], multipliers[b]); + } } } printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M); @@ -294,7 +373,7 @@ class MlasFgemmTest : public MlasTestBase { for (size_t M = 0; M < 160; M++) { for (size_t N = 0; N < 160; N++) { for (size_t K = 0; K < 160; K++) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } } printf("M %zd\n", M); @@ -303,10 +382,10 @@ class MlasFgemmTest : public MlasTestBase { for (size_t M = 160; M < 320; M += 24) { for (size_t N = 112; N < 320; N += 24) { for (size_t K = 0; K < 16; K++) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } } printf("M %zd\n", M); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 2780619809..8a0a936e79 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -14,20 +14,20 @@ template class FgemmShortExecuteTest : public MlasTestFixture> { public: - explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) - : trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), alpha_(alpha), beta_(beta) { + explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) + : trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), Batch_(BatchSize), alpha_(alpha), beta_(beta) { } void TestBody() override { MlasTestFixture>::mlas_tester->Test( - trans_a_, trans_b_, M_, N_, K_, alpha_, beta_); + trans_a_, trans_b_, M_, N_, K_, Batch_, alpha_, beta_); } - static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) { + static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { std::stringstream ss; ss << (trans_a ? "TransA" : "A") << "/" << (trans_b ? "TransB" : "B") << "/" - << "M" << M << "xN" << N << "xK" << K << "/" + << "BatchSize" << BatchSize << "/M" << M << "xN" << N << "xK" << K << "/" << "Alpha" << alpha << "/" << "Beta" << beta; auto test_name = ss.str(); @@ -42,37 +42,39 @@ class FgemmShortExecuteTest : public MlasTestFixture MlasTestFixture>* { return new FgemmShortExecuteTest( - trans_a, trans_b, M, N, K, alpha, beta); + trans_a, trans_b, M, N, K, BatchSize, alpha, beta); }); return 1; } - static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, float alpha, float beta) { - return RegisterSingleTest(false, false, M, N, K, alpha, beta) + - RegisterSingleTest(false, true, M, N, K, alpha, beta) + - RegisterSingleTest(true, false, M, N, K, alpha, beta) + - RegisterSingleTest(true, true, M, N, K, alpha, beta); + static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { + return RegisterSingleTest(false, false, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(false, true, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(true, false, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(true, true, M, N, K, BatchSize, alpha, beta); } static size_t RegisterShortExecuteTests() { size_t test_registered = 0; for (size_t b = 0; b < 16; b++) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 3, 1.0f, 0.0f); } for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); } for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); } - test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1.0f, 0.0f); - test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(25, 81, 79, 7, 1.0f, 0.0f); return test_registered; } private: bool trans_a_, trans_b_; - size_t M_, N_, K_; - float alpha_, beta_; + const size_t M_, N_, K_, Batch_; + const T alpha_, beta_; };