diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 852247c43b..68e5b6815e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -309,6 +309,7 @@ Status Attention::Compute(OpKernelContext* context) const { T* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); + // TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process? // broadcast 3NH -> (3.B.N.S.H) const T* broadcast_data_src = bias_data + weights_offset; T* broadcast_data_dest = QKV[qkv_index] + qkv_offset; diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cad2bb8546..c417bdb17d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -147,10 +147,102 @@ MlasActivation( // // Matrix/matrix multiply routines. +// C := alpha * op(A) * op(B) + beta * C +// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X)) // +/** + * @brief Supply matrices data information to single precision gemm functions + */ +struct MLAS_SGEMM_DATA_PARAMS { + const float* A = nullptr; /**< Supplies the address of matrix A */ + size_t lda = 0; /**< Supplies the first dimension of matrix A. */ + const float* B = nullptr; /**< Supplies the address of matrix B */ + size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ + float* C = nullptr; /**< Supplies the address of matrix C */ + size_t ldc = 0; /**< Supplies the first dimension of matrix C. */ + float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ + float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ + bool BIsPacked = false; /**< Whether B is pre-packed */ +}; + +/** + * @brief Batched single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ void MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief Single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data Supplies the matrices data parameters + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS& Data, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** + * @brief Single precision matrix/matrix multiply operation (SGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param beta Supplies the scalar beta multiplier (see SGEMM definition) + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void MlasGemm( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, @@ -166,10 +258,41 @@ MlasGemm( float* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_SGEMM_DATA_PARAMS Data; + Data.alpha = alpha; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.beta = beta; + Data.C = C; + Data.ldc = ldc; + MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool); +} + +/** + * @brief the single precision matrix/matrix multiply operation (SGEMM) with pre-packed B + * + * @param TransA - Supplies the transpose operation for matrix A. + * @param M - Supplies the number of rows of matrix A and matrix C. + * @param N - Supplies the number of columns of matrix B and matrix C. + * @param K - Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha - Supplies the scalar alpha multiplier (see SGEMM definition). + * @param A - Supplies the address of matrix A. + * @param lda - Supplies the first dimension of matrix A. + * @param PackedB - Supplies the address of packed matrix B. + * @param beta - Supplies the scalar beta multiplier (see SGEMM definition). + * @param C - Supplies the address of matrix C. + * @param ldc - Supplies the first dimension of matrix C. + * @param ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline void -MLASCALL MlasGemm( CBLAS_TRANSPOSE TransA, size_t M, @@ -183,10 +306,117 @@ MlasGemm( float* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_SGEMM_DATA_PARAMS DataParams; + DataParams.A = A; + DataParams.lda = lda; + DataParams.B = static_cast(PackedB); + DataParams.ldb = 0; + DataParams.C = C; + DataParams.ldc = ldc; + DataParams.alpha = alpha; + DataParams.beta = beta; + DataParams.BIsPacked = true; + MlasGemmBatch(TransA, + CblasTrans, // deos not matter when B is packed + M, N, K, &DataParams, 1, ThreadPool); +} + + +/** + * @brief Supply matrices data information to double precision gemm functions + */ +struct MLAS_DGEMM_DATA_PARAMS { + const double* A = nullptr; /**< Supplies the address of matrix A */ + size_t lda = 0; /**< Supplies the first dimension of matrix A. */ + const double* B = nullptr; /**< Supplies the address of matrix B */ + size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ + double* C = nullptr; /**< Supplies the address of matrix C */ + size_t ldc = 0; /**< Supplies the first dimension of matrix C. */ + double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ + double beta = 0.0; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ +}; + +/** + * @brief Batched double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ void MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + + +/** + * @brief Double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param Data Supplies the matrices data parameters + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS& Data, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** + * @brief Double precision matrix/matrix multiply operation (DGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number + of rows of matrix B. + * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param beta Supplies the scalar beta multiplier (see SGEMM definition) + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +inline +void MlasGemm( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, @@ -202,7 +432,19 @@ MlasGemm( double* C, size_t ldc, MLAS_THREADPOOL* ThreadPool - ); + ) +{ + MLAS_DGEMM_DATA_PARAMS Data; + Data.alpha = alpha; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.beta = beta; + Data.C = C; + Data.ldc = ldc; + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} enum class MLAS_QUANTIZATION_GRANULARITY { PerMatrix, diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp index 666856a6b6..d809e07d2a 100644 --- a/onnxruntime/core/mlas/lib/dgemm.cpp +++ b/onnxruntime/core/mlas/lib/dgemm.cpp @@ -26,29 +26,6 @@ Abstract: #define MLAS_DGEMM_TRANSA_ROWS 12 -// -// Define the parameters to execute segments of a DGEMM operation on worker -// threads. -// - -struct MLAS_DGEMM_WORK_BLOCK { - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - CBLAS_TRANSPOSE TransA; - CBLAS_TRANSPOSE TransB; - size_t M; - size_t N; - size_t K; - const double* A; - size_t lda; - const double* B; - size_t ldb; - double* C; - size_t ldc; - double alpha; - double beta; -}; - #ifdef MLAS_TARGET_AMD64 void @@ -750,8 +727,15 @@ Return Value: void MlasDgemmThreaded( - void* Context, - ptrdiff_t ThreadId + const ptrdiff_t ThreadCountM, + const ptrdiff_t ThreadCountN, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + const ptrdiff_t ThreadId ) /*++ @@ -772,10 +756,6 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context; - - const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM; - const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN; const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; @@ -784,7 +764,6 @@ Return Value: // Partition the operation along the M dimension. // - size_t M = WorkBlock->M; size_t RangeStartM; size_t RangeCountM; @@ -794,7 +773,6 @@ Return Value: // Partition the operation along the N dimension. // - size_t N = WorkBlock->N; size_t RangeStartN; size_t RangeCountN; @@ -813,50 +791,32 @@ Return Value: // Dispatch the partitioned operation. // - CBLAS_TRANSPOSE TransA = WorkBlock->TransA; - CBLAS_TRANSPOSE TransB = WorkBlock->TransB; + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; - const size_t lda = WorkBlock->lda; - const size_t ldb = WorkBlock->ldb; - const size_t ldc = WorkBlock->ldc; + const double* A = Data->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); + const double* B = Data->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); + double* C = Data->C + RangeStartM * ldc + RangeStartN; - const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; - - MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, - WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); + MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, + Data->alpha, A, lda, B, ldb, Data->beta, C, ldc); } + void -MlasDgemmSchedule( - MLAS_DGEMM_WORK_BLOCK* WorkBlock, +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_DGEMM_DATA_PARAMS* Data, + size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) -/*++ - -Routine Description: - - This routine schedules the double precision matrix/matrix multiply - operation (DGEMM) across one or more threads. - -Arguments: - - WorkBlock - Supplies the structure containing the GEMM parameters. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ { - const size_t M = WorkBlock->M; - const size_t N = WorkBlock->N; - const size_t K = WorkBlock->K; - // // Compute the number of target threads given the complexity of the DGEMM // operation. Small requests should run using the single threaded path. @@ -885,121 +845,40 @@ Return Value: // works okay for operations involving skinny matrices. // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + if (N > M) { const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - if (size_t(TargetThreadCount) > BlockedN) { - TargetThreadCount = ptrdiff_t(BlockedN); + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); } - WorkBlock->ThreadCountM = 1; - WorkBlock->ThreadCountN = TargetThreadCount; + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; } else { - if (size_t(TargetThreadCount) > M) { - TargetThreadCount = ptrdiff_t(M); + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); } - WorkBlock->ThreadCountM = TargetThreadCount; - WorkBlock->ThreadCountN = 1; + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; } - MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool); -} + const ptrdiff_t TotalThreads = ThreadsPerGemm * static_cast(BatchSize); + MlasTrySimpleParallel(ThreadPool, TotalThreads, [=](ptrdiff_t tid) { + const ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + const ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + MlasDgemmThreaded(ThreadCountM, ThreadCountN, TransA, TransB, + M, N, K, &(Data[GemmIdx]), ThreadIdx); + }); -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the double precision matrix/matrix multiply - operation (DGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see DGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see DGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_DGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.TransB = TransB; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - - // - // Schedule the operation across a set of worker threads. - // - - MlasDgemmSchedule(&WorkBlock, ThreadPool); } #endif diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 34ef3f9fa1..8934a8b520 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -31,25 +31,6 @@ Abstract: // threads. // -struct MLAS_SGEMM_WORK_BLOCK { - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - CBLAS_TRANSPOSE TransA; - CBLAS_TRANSPOSE TransB; - size_t M; - size_t N; - size_t K; - const float* A; - size_t lda; - const void* B; - size_t ldb; - float* C; - size_t ldc; - float alpha; - float beta; - bool BIsPacked; -}; - void MlasSgemmMultiplyBeta( float* C, @@ -1476,7 +1457,15 @@ Return Value: void MlasSgemmThreaded( - void* Context, + const ptrdiff_t ThreadCountM, + const ptrdiff_t ThreadCountN, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + + const MLAS_SGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId ) /*++ @@ -1488,7 +1477,17 @@ Routine Description: Arguments: - Context - Supplies the pointer to the context for the threaded operation. + ThreadCountM - Supplies the total thread partition on the M dimension. + + ThreadCountN - Supplies the total thread partition on the N dimension. + + TransA - Supplies the transpose operation on A matrix + + TransB - Supplies the transpose operation on B matrix + + M, N, K - Supplies the shape of the multiplication + + DataParams - Supplies the data position and layout of the matrices ThreadId - Supplies the current index of the threaded operation. @@ -1498,10 +1497,6 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_SGEMM_WORK_BLOCK*)Context; - - const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM; - const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN; const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; @@ -1510,7 +1505,6 @@ Return Value: // Partition the operation along the M dimension. // - size_t M = WorkBlock->M; size_t RangeStartM; size_t RangeCountM; @@ -1520,7 +1514,6 @@ Return Value: // Partition the operation along the N dimension. // - size_t N = WorkBlock->N; size_t RangeStartN; size_t RangeCountN; @@ -1539,61 +1532,42 @@ Return Value: // Dispatch the partitioned operation. // - CBLAS_TRANSPOSE TransA = WorkBlock->TransA; + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; - const size_t lda = WorkBlock->lda; - const size_t ldc = WorkBlock->ldc; + const float* A = DataParams->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; - - if (WorkBlock->BIsPacked) { + if (DataParams->BIsPacked) { MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, - WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B, - BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc); + K, DataParams->alpha, A, lda, DataParams->B, + BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc); } else { - CBLAS_TRANSPOSE TransB = WorkBlock->TransB; + const size_t ldb = DataParams->ldb; - const size_t ldb = WorkBlock->ldb; + const float* B = (const float*)DataParams->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - - MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, - WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); + MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, + DataParams->alpha, A, lda, B, ldb, DataParams->beta, C, ldc); } } void -MlasSgemmSchedule( - MLAS_SGEMM_WORK_BLOCK* WorkBlock, +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) -/*++ - -Routine Description: - - This routine schedules the single precision matrix/matrix multiply - operation (SGEMM) across one or more threads. - -Arguments: - - WorkBlock - Supplies the structure containing the GEMM parameters. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ { - const size_t M = WorkBlock->M; - const size_t N = WorkBlock->N; - const size_t K = WorkBlock->K; // // Compute the number of target threads given the complexity of the SGEMM @@ -1623,206 +1597,41 @@ Return Value: // works okay for operations involving skinny matrices. // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + if (N > M) { const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - if (size_t(TargetThreadCount) > BlockedN) { - TargetThreadCount = ptrdiff_t(BlockedN); + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); } - WorkBlock->ThreadCountM = 1; - WorkBlock->ThreadCountN = TargetThreadCount; + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; } else { - if (size_t(TargetThreadCount) > M) { - TargetThreadCount = ptrdiff_t(M); + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); } - WorkBlock->ThreadCountM = TargetThreadCount; - WorkBlock->ThreadCountN = 1; + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; } - MlasExecuteThreaded(MlasSgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool); -} - -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_SGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.TransB = TransB; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = B; - WorkBlock.ldb = ldb; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - - // - // Schedule the operation across a set of worker threads. - // - - MlasSgemmSchedule(&WorkBlock, ThreadPool); -} - -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const void* PackedB, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - PackedB - Supplies the address of packed matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_SGEMM_WORK_BLOCK WorkBlock; - - // - // Capture the GEMM parameters to the work block. - // - - memset(&WorkBlock, 0, sizeof(MLAS_SGEMM_WORK_BLOCK)); - - WorkBlock.TransA = TransA; - WorkBlock.M = M; - WorkBlock.N = N; - WorkBlock.K = K; - WorkBlock.A = A; - WorkBlock.lda = lda; - WorkBlock.B = PackedB; - WorkBlock.C = C; - WorkBlock.ldc = ldc; - WorkBlock.alpha = alpha; - WorkBlock.beta = beta; - WorkBlock.BIsPacked = true; - - // - // Schedule the operation across a set of worker threads. - // - - MlasSgemmSchedule(&WorkBlock, ThreadPool); + MlasTrySimpleParallel(ThreadPool, + ThreadsPerGemm * static_cast(BatchSize), + [=](ptrdiff_t tid) + { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + MlasSgemmThreaded(ThreadCountM, ThreadCountN, + TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx); + }); } size_t diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index 85e6b2ef7e..8769abdf08 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -88,7 +88,7 @@ MlasTrySimpleParallel( #ifdef _OPENMP #pragma omp parallel for #endif - + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { Work(tid); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index adb4419d68..6165d922fa 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -159,38 +159,27 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const auto* b_data = b ? b->Data() : nullptr; auto* y_data = y->MutableData(); - // TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well - size_t max_len = helper.OutputOffsets().size(); + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = static_cast(trans_a ? M : K); + const size_t ldb = static_cast(trans_b ? K : N); + + std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { - if (packed_b_) { - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - alpha_attr_, - a_data + helper.LeftOffsets()[i], - static_cast(trans_a ? helper.M() : helper.K()), - packed_b_.get(), - 0.0f, - y_data + helper.OutputOffsets()[i], - static_cast(helper.N()), - thread_pool); - continue; - } - math::Gemm( - trans_a ? CblasTrans : CblasNoTrans, - trans_b ? CblasTrans : CblasNoTrans, - helper.M(), - helper.N(), - helper.K(), - alpha_attr_, - a_data + helper.LeftOffsets()[i], - b_data + helper.RightOffsets()[i], - 0.0f, - y_data + helper.OutputOffsets()[i], - thread_pool); + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc index 82b808d9e0..50a41cfa4c 100644 --- a/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc @@ -108,6 +108,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const { MlasGemm(gemm_shape, gemm_params, ctx->GetOperatorThreadPool()); + //TODO!! consider making this a post processor, so that we can parallize this loop MlasRequantizeOutput(gemm_output, y->template MutableData() + helper.OutputOffsets()[i], nullptr, diff --git a/onnxruntime/test/mlas/bench/bench_qgemm.cpp b/onnxruntime/test/mlas/bench/bench_qgemm.cpp index 76f4caf655..86723a931b 100644 --- a/onnxruntime/test/mlas/bench/bench_qgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qgemm.cpp @@ -17,18 +17,18 @@ void QGEMM(benchmark::State& state, bool pack_b) { const uint8_t a_zero_point = 29; const uint8_t b_zero_point = 179; - const int64_t M = state.range(0); - const int64_t N = state.range(1); - const int64_t K = state.range(2); + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("Batch must greater than 0!"); + if (state.range(4) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - const int64_t batch = state.range(3); - const int64_t threads = state.range(4); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); - if (M <= 0) throw std::invalid_argument("M must greater than 0!"); - if (N <= 0) throw std::invalid_argument("N must greater than 0!"); - if (K <= 0) throw std::invalid_argument("K must greater than 0!"); - if (batch <= 0) throw std::invalid_argument("Batch must greater than 0!"); - if (threads <= 0) throw std::invalid_argument("Threads must greater than 0!"); + const size_t batch = static_cast(state.range(3)); + const size_t threads = static_cast(state.range(4)); OrtThreadPoolParams tpo; tpo.thread_pool_size = int(threads); diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index 4336311a90..93ae0f2eb2 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -10,13 +10,13 @@ static const std::vector sgemm_bench_arg_names = {"M", "N", "K"}; void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) { - const int64_t M = state.range(0); - const int64_t N = state.range(1); - const int64_t K = state.range(2); + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); - if (M <= 0) throw std::invalid_argument("M must greater than 0!"); - if (N <= 0) throw std::invalid_argument("N must greater than 0!"); - if (K <= 0) throw std::invalid_argument("K must greater than 0!"); auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); auto B = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.cpp b/onnxruntime/test/mlas/unittest/test_fgemm.cpp index 710083c6e0..89af5234ee 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_fgemm.cpp @@ -7,7 +7,6 @@ #include #include - template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); template <> MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h index 5a161ebafd..35d4622003 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm.h @@ -21,8 +21,8 @@ const char* GetGemmTestSuitePrefix() { template class FgemmPackedContext; -template -class FgemmPackedContext { +template <> +class FgemmPackedContext { public: void TestGemm( @@ -31,21 +31,69 @@ class FgemmPackedContext { size_t M, size_t N, size_t K, - float alpha, - const T* A, + size_t BatchSize, + const float alpha, + const float* A, size_t lda, - const T* B, + const float* B, size_t ldb, - float beta, - T* C, + const float beta, + float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - MlasGemm(TransA, TransB, M, N, K, T(alpha), A, lda, B, ldb, T(beta), C, ldc, threadpool); + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = B + K * N * i; + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); } }; -template -class FgemmPackedContext { +#ifdef MLAS_TARGET_AMD64 +template <> +class FgemmPackedContext { + public: + void TestGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + size_t BatchSize, + double alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + double beta, + double* C, + size_t ldc, + MLAS_THREADPOOL* threadpool) { + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = B + K * N * i; + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); + } +}; +#endif + +template <> +class FgemmPackedContext { public: void TestGemm( @@ -54,19 +102,34 @@ class FgemmPackedContext { size_t M, size_t N, size_t K, - float alpha, - const T* A, + size_t BatchSize, + const float alpha, + const float* A, size_t lda, - const T* B, + const float* B, size_t ldb, - float beta, - T* C, + const float beta, + float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { size_t PackedBSize = MlasGemmPackBSize(N, K); - void* PackedB = BufferBPacked.GetBuffer(PackedBSize, true); - MlasGemmPackB(TransB, N, K, B, ldb, PackedB); - MlasGemm(TransA, M, N, K, T(alpha), A, lda, PackedB, T(beta), C, ldc, threadpool); + void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); + std::vector data(BatchSize); + for (size_t i = 0; i < BatchSize; i++) { + MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + data[i].BIsPacked = true; + data[i].A = A + M * K * i; + data[i].lda = lda; + data[i].B = (float*)((uint8_t*)PackedB + PackedBSize * i); + data[i].ldb = ldb; + data[i].C = C + M * N * i; + data[i].ldc = ldc; + data[i].alpha = alpha; + data[i].beta = beta; + } + MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); + + MlasGemm(TransA, M, N, K, alpha, A, lda, PackedB, beta, C, ldc, threadpool); } private: @@ -89,14 +152,14 @@ class MlasFgemmTest : public MlasTestBase { MlasFgemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) { } - void Test(size_t M, size_t N, size_t K, float alpha, float beta) { - Test(false, false, M, N, K, alpha, beta); - Test(false, true, M, N, K, alpha, beta); - Test(true, false, M, N, K, alpha, beta); - Test(true, true, M, N, K, alpha, beta); + void Test(size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { + Test(false, false, M, N, K, BatchSize, alpha, beta); + Test(false, true, M, N, K, BatchSize, alpha, beta); + Test(true, false, M, N, K, BatchSize, alpha, beta); + Test(true, true, M, N, K, BatchSize, alpha, beta); } - void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) { + void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { // // Skip the test if the B buffer cannot be packed. // @@ -105,14 +168,14 @@ class MlasFgemmTest : public MlasTestBase { return; } - const T* A = BufferA.GetBuffer(K * M); - const T* B = BufferB.GetBuffer(N * K); - T* C = BufferC.GetBuffer(N * M); - T* CReference = BufferCReference.GetBuffer(N * M); + const T* A = BufferA.GetBuffer(K * M * BatchSize); + const T* B = BufferB.GetBuffer(N * K * BatchSize); + T* C = BufferC.GetBuffer(N * M * BatchSize); + T* CReference = BufferCReference.GetBuffer(N * M * BatchSize); Test(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, alpha, A, trans_a ? M : K, B, trans_b ? K : N, + M, N, K, BatchSize, alpha, A, trans_a ? M : K, B, trans_b ? K : N, beta, C, CReference, N); } @@ -121,33 +184,36 @@ class MlasFgemmTest : public MlasTestBase { size_t M, size_t N, size_t K, - float alpha, + size_t BatchSize, + T alpha, const T* A, size_t lda, const T* B, size_t ldb, - float beta, + T beta, T* C, T* CReference, size_t ldc) { - std::fill_n(C, M * N, -0.5f); - std::fill_n(CReference, M * N, -0.5f); + std::fill_n(C, M * N * BatchSize, -0.5f); + std::fill_n(CReference, M * N * BatchSize, -0.5f); - PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); - ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc); + PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); + ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc); - for (size_t m = 0, f = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - // Sensitive to comparing positive/negative zero. - ASSERT_EQ(C[f], CReference[f]) - << " Diff @[" << m << ", " << n << "] f=" << f << ", " - << (Packed ? "Packed" : "NoPack") << "." - << (Threaded ? "SingleThread" : "Threaded") << "/" - << (TransA == CblasTrans ? "TransA" : "A") << "/" - << (TransB == CblasTrans ? "TransB" : "B") << "/" - << "M" << M << "xN" << N << "xK" << K << "/" - << "Alpha" << alpha << "/" - << "Beta" << beta; + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + // Sensitive to comparing positive/negative zero. + ASSERT_EQ(C[f], CReference[f]) + << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " + << (Packed ? "Packed" : "NoPack") << "." + << (Threaded ? "SingleThread" : "Threaded") << "/" + << (TransA == CblasTrans ? "TransA" : "A") << "/" + << (TransB == CblasTrans ? "TransB" : "B") << "/" + << "M" << M << "xN" << N << "xK" << K << "/" + << "Alpha" << alpha << "/" + << "Beta" << beta; + } } } } @@ -157,111 +223,120 @@ class MlasFgemmTest : public MlasTestBase { size_t M, size_t N, size_t K, - float alpha, + size_t BatchSize, + T alpha, const T* A, size_t lda, const T* B, size_t ldb, - float beta, + T beta, T* C, size_t ldc) { - if (TransA == CblasNoTrans) { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; + for (size_t batch = 0; batch < BatchSize; batch++) { + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + (m * lda); + const T* b = B + n; + T* c = C + (m * ldc) + n; + T sum = 0.0f; - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += 1; + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += ldb; + a += 1; + } + + *c = (*c * beta) + (sum * alpha); } + } - *c = (*c * beta) + (sum * alpha); + } else { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + (m * lda); + const T* b = B + (n * ldb); + T* c = C + (m * ldc) + n; + T sum = 0.0f; + + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += 1; + a += 1; + } + + *c = (*c * beta) + (sum * alpha); + } } } } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; + if (TransB == CblasNoTrans) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + m; + const T* b = B + n; + T* c = C + (m * ldc) + n; + T sum = 0.0f; - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += 1; + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += ldb; + a += lda; + } + + *c = (*c * beta) + (sum * alpha); } + } - *c = (*c * beta) + (sum * alpha); - } - } - } - - } else { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); - } - } - - } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); + } else { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const T* a = A + m; + const T* b = B + (n * ldb); + T* c = C + (m * ldc) + n; + T sum = 0.0f; + + for (size_t k = 0; k < K; k++) { + sum += (*b * *a); + b += 1; + a += lda; + } + + *c = (*c * beta) + (sum * alpha); + } } } } + A += M * K; + B += K * N; + C += M * N; } } void ExecuteLong() override { - static const float multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f}; + static const T multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f}; for (size_t N = 1; N < 128; N++) { for (size_t K = 1; K < 128; K++) { for (size_t a = 0; a < _countof(multipliers); a++) { for (size_t b = 0; b < _countof(multipliers); b++) { - Test(1, N, K, multipliers[a], multipliers[b]); - Test(N, 1, K, multipliers[a], multipliers[b]); + Test(1, N, K, 1, multipliers[a], multipliers[b]); + Test(N, 1, K, 1, multipliers[a], multipliers[b]); + if (!Packed) { + Test(1, N, K, 3, multipliers[a], multipliers[b]); + } } } } } for (size_t a = 0; a < _countof(multipliers); a++) { - float alpha = multipliers[a]; + T alpha = multipliers[a]; for (size_t b = 0; b < _countof(multipliers); b++) { - float beta = multipliers[b]; + T beta = multipliers[b]; for (size_t M = 16; M < 160; M += 32) { for (size_t N = 16; N < 160; N += 32) { @@ -269,21 +344,25 @@ class MlasFgemmTest : public MlasTestBase { for (size_t k = 0; k < _countof(ks); k++) { size_t K = ks[k]; - Test(M, N, K, alpha, beta); - Test(M + 1, N, K, alpha, beta); - Test(M, N + 1, K, alpha, beta); - Test(M + 1, N + 1, K, alpha, beta); - Test(M + 3, N + 2, K, alpha, beta); - Test(M + 4, N, K, alpha, beta); - Test(M, N + 4, K, alpha, beta); - Test(M + 4, N + 4, K, alpha, beta); - Test(M + 3, N + 7, K, alpha, beta); - Test(M + 8, N, K, alpha, beta); - Test(M, N + 8, K, alpha, beta); - Test(M + 12, N + 12, K, alpha, beta); - Test(M + 13, N, K, alpha, beta); - Test(M, N + 15, K, alpha, beta); - Test(M + 15, N + 15, K, alpha, beta); + Test(M, N, K, 1, alpha, beta); + Test(M + 1, N, K, 1, alpha, beta); + Test(M, N + 1, K, 1, alpha, beta); + Test(M + 1, N + 1, K, 1, alpha, beta); + Test(M + 3, N + 2, K, 1, alpha, beta); + Test(M + 4, N, K, 1, alpha, beta); + Test(M, N + 4, K, 1, alpha, beta); + Test(M + 4, N + 4, K, 1, alpha, beta); + Test(M + 3, N + 7, K, 1, alpha, beta); + Test(M + 8, N, K, 1, alpha, beta); + Test(M, N + 8, K, 1, alpha, beta); + Test(M + 12, N + 12, K, 1, alpha, beta); + Test(M + 13, N, K, 1, alpha, beta); + Test(M, N + 15, K, 1, alpha, beta); + Test(M + 15, N + 15, K, 1, alpha, beta); + if (!Packed) { + Test(M + 3, N + 1, K, 7, multipliers[a], multipliers[b]); + Test(M + 13, N + 2, K, 9, multipliers[a], multipliers[b]); + } } } printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M); @@ -294,7 +373,7 @@ class MlasFgemmTest : public MlasTestBase { for (size_t M = 0; M < 160; M++) { for (size_t N = 0; N < 160; N++) { for (size_t K = 0; K < 160; K++) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } } printf("M %zd\n", M); @@ -303,10 +382,10 @@ class MlasFgemmTest : public MlasTestBase { for (size_t M = 160; M < 320; M += 24) { for (size_t N = 112; N < 320; N += 24) { for (size_t K = 0; K < 16; K++) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1.0f, 0.0f); + Test(M, N, K, 1, 1.0f, 0.0f); } } printf("M %zd\n", M); diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 2780619809..8a0a936e79 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -14,20 +14,20 @@ template class FgemmShortExecuteTest : public MlasTestFixture> { public: - explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) - : trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), alpha_(alpha), beta_(beta) { + explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) + : trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), Batch_(BatchSize), alpha_(alpha), beta_(beta) { } void TestBody() override { MlasTestFixture>::mlas_tester->Test( - trans_a_, trans_b_, M_, N_, K_, alpha_, beta_); + trans_a_, trans_b_, M_, N_, K_, Batch_, alpha_, beta_); } - static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, float alpha, float beta) { + static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { std::stringstream ss; ss << (trans_a ? "TransA" : "A") << "/" << (trans_b ? "TransB" : "B") << "/" - << "M" << M << "xN" << N << "xK" << K << "/" + << "BatchSize" << BatchSize << "/M" << M << "xN" << N << "xK" << K << "/" << "Alpha" << alpha << "/" << "Beta" << beta; auto test_name = ss.str(); @@ -42,37 +42,39 @@ class FgemmShortExecuteTest : public MlasTestFixture MlasTestFixture>* { return new FgemmShortExecuteTest( - trans_a, trans_b, M, N, K, alpha, beta); + trans_a, trans_b, M, N, K, BatchSize, alpha, beta); }); return 1; } - static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, float alpha, float beta) { - return RegisterSingleTest(false, false, M, N, K, alpha, beta) + - RegisterSingleTest(false, true, M, N, K, alpha, beta) + - RegisterSingleTest(true, false, M, N, K, alpha, beta) + - RegisterSingleTest(true, true, M, N, K, alpha, beta); + static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { + return RegisterSingleTest(false, false, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(false, true, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(true, false, M, N, K, BatchSize, alpha, beta) + + RegisterSingleTest(true, true, M, N, K, BatchSize, alpha, beta); } static size_t RegisterShortExecuteTests() { size_t test_registered = 0; for (size_t b = 0; b < 16; b++) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 3, 1.0f, 0.0f); } for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); } for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); } - test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1.0f, 0.0f); - test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1, 1.0f, 0.0f); + test_registered += RegisterTestTransposeABProduct(25, 81, 79, 7, 1.0f, 0.0f); return test_registered; } private: bool trans_a_, trans_b_; - size_t M_, N_, K_; - float alpha_, beta_; + const size_t M_, N_, K_, Batch_; + const T alpha_, beta_; };