From ecb2e119e468facdd04a49cf04e2d9d70084ad95 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Wed, 30 Dec 2020 23:25:10 -0800 Subject: [PATCH] MLAS: handle MlasGemm(M/N/K==0) cases (#6238) --- onnxruntime/core/mlas/lib/dgemm.cpp | 523 ++++++++++++++-------------- onnxruntime/core/mlas/lib/sgemm.cpp | 54 +-- onnxruntime/test/mlas/unittest.cpp | 29 +- 3 files changed, 312 insertions(+), 294 deletions(-) diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp index 29805a97f1..48c29e0d37 100644 --- a/onnxruntime/core/mlas/lib/dgemm.cpp +++ b/onnxruntime/core/mlas/lib/dgemm.cpp @@ -32,21 +32,21 @@ Abstract: // struct MLAS_DGEMM_WORK_BLOCK { + int32_t ThreadCountM; + int32_t ThreadCountN; CBLAS_TRANSPOSE TransA; CBLAS_TRANSPOSE TransB; + size_t M; + size_t N; size_t K; + const double* A; size_t lda; + const double* B; size_t ldb; + double* C; size_t ldc; double alpha; double beta; - struct SEGMENT { - size_t M; - size_t N; - const double* A; - const double* B; - double* C; - } Segments[MLAS_MAXIMUM_THREAD_COUNT]; }; #ifdef MLAS_TARGET_AMD64 @@ -86,7 +86,7 @@ Return Value: { MLAS_FLOAT64X2 BetaBroadcast = MlasBroadcastFloat64x2(beta); - do { + while (CountM-- > 0) { double* c = C; size_t n = CountN; @@ -106,9 +106,7 @@ Return Value: } C += ldc; - CountM--; - - } while (CountM > 0); + } } void @@ -497,6 +495,82 @@ Return Value: } } +MLAS_FORCEINLINE +double* +MlasDgemmKernelLoop( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + double alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine steps through the rows of the input and output matrices calling + the kernel until all rows have been processed. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasDgemmCopyPackB or MlasDgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the number of rows from matrix A and matrix C to iterate + over. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar alpha multiplier (see DGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the next address of matrix C. + +--*/ +{ + while (CountM > 0) { + + size_t RowsHandled; + +#if defined(MLAS_TARGET_AMD64_IX86) + RowsHandled = MlasPlatform.GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); +#else + if (ZeroMode) { + RowsHandled = MlasDgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } else { + RowsHandled = MlasDgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } +#endif + + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } + + return C; +} + void MlasDgemmOperation( CBLAS_TRANSPOSE TransA, @@ -558,6 +632,16 @@ Return Value: double PanelA[MLAS_DGEMM_TRANSA_ROWS * MLAS_DGEMM_STRIDEK]; MLAS_DECLSPEC_ALIGN(double PanelB[MLAS_DGEMM_STRIDEN * MLAS_DGEMM_STRIDEK], 8 * sizeof(double)); + // + // Handle the special case of K equals zero. Apply the beta multiplier to + // the output matrix and exit. + // + + if (K == 0) { + MlasDgemmMultiplyBeta(C, M, N, ldc, beta); + return; + } + // // Compute the strides to step through slices of the input matrices. // @@ -566,8 +650,8 @@ Return Value: // the A panel needs to be used for transposing. // - uint32_t StrideN = MLAS_DGEMM_STRIDEN; - uint32_t StrideK = MLAS_DGEMM_STRIDEK; + size_t StrideN = MLAS_DGEMM_STRIDEN; + size_t StrideK = MLAS_DGEMM_STRIDEK; if (N >= K) { @@ -589,15 +673,10 @@ Return Value: // size_t CountN; - size_t CountK; for (size_t n = 0; n < N; n += CountN) { - CountN = StrideN; - - if (CountN > (N - n)) { - CountN = N - n; - } + CountN = std::min(N - n, StrideN); // // Multiply the output matrix by beta as needed. @@ -611,15 +690,12 @@ Return Value: // Step through each slice of matrix B along the K dimension. // + size_t CountK; + bool ZeroMode = (beta == 0.0f); + for (size_t k = 0; k < K; k += CountK) { - bool ZeroMode = (k == 0 && beta == 0.0f); - - CountK = StrideK; - - if (CountK > (K - k)) { - CountK = K - k; - } + CountK = std::min(K - k, StrideK); // // Copy or transpose a panel of matrix B to a local packed buffer. @@ -637,93 +713,45 @@ Return Value: double* c = C + n; - size_t RowsRemaining = M; - size_t RowsHandled; - if (TransA == CblasNoTrans) { - const double* a = A + k; - - // - // Step through the rows of matrix A. - // - - do { - -#if defined(MLAS_TARGET_AMD64_IX86) - RowsHandled = MlasPlatform.GemmDoubleKernel(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasDgemmKernelZero(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasDgemmKernelAdd(a, PanelB, c, CountK, RowsRemaining, CountN, lda, ldc, alpha); - } -#endif - - c += ldc * RowsHandled; - a += lda * RowsHandled; - - RowsRemaining -= RowsHandled; - - } while (RowsRemaining > 0); + MlasDgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); } else { const double* a = A + k * lda; + size_t RowsRemaining = M; - do { + while (RowsRemaining > 0) { // // Transpose elements from matrix A into a local buffer. // - size_t RowsTransposed = RowsRemaining; - - if (RowsTransposed > MLAS_DGEMM_TRANSA_ROWS) { - RowsTransposed = MLAS_DGEMM_TRANSA_ROWS; - } - - RowsRemaining -= RowsTransposed; + size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_DGEMM_TRANSA_ROWS)); MlasDgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); + RowsRemaining -= RowsTransposed; a += RowsTransposed; // // Step through the rows of the local buffer. // - const double* pa = PanelA; - - do { - -#if defined(MLAS_TARGET_AMD64_IX86) - RowsHandled = MlasPlatform.GemmDoubleKernel(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasDgemmKernelZero(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha); - } else { - RowsHandled = MlasDgemmKernelAdd(pa, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha); - } -#endif - - c += ldc * RowsHandled; - pa += CountK * RowsHandled; - - RowsTransposed -= RowsHandled; - - } while (RowsTransposed > 0); - - } while (RowsRemaining > 0); + c = MlasDgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); + } } + + ZeroMode = false; } } } void -MlasDgemmOperationThreaded( +MlasDgemmThreaded( void* Context, - int32_t Index + int32_t ThreadId ) /*++ @@ -736,7 +764,7 @@ Arguments: Context - Supplies the pointer to the context for the threaded operation. - Index - Supplies the current index of the threaded operation. + ThreadId - Supplies the current index of the threaded operation. Return Value: @@ -744,19 +772,147 @@ Return Value: --*/ { - MLAS_DGEMM_WORK_BLOCK* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context; + const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context; - MLAS_DGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index]; + const int32_t ThreadCountM = WorkBlock->ThreadCountM; + const int32_t ThreadCountN = WorkBlock->ThreadCountN; - MlasDgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M, - Segment->N, WorkBlock->K, WorkBlock->alpha, Segment->A, WorkBlock->lda, - Segment->B, WorkBlock->ldb, WorkBlock->beta, Segment->C, - WorkBlock->ldc); + const int32_t ThreadIdM = ThreadId / ThreadCountN; + const int32_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + + size_t M = WorkBlock->M; + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + + size_t N = WorkBlock->N; + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_DGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, + &RangeCountN); + + RangeStartN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + + CBLAS_TRANSPOSE TransA = WorkBlock->TransA; + CBLAS_TRANSPOSE TransB = WorkBlock->TransB; + + const size_t lda = WorkBlock->lda; + const size_t ldb = WorkBlock->ldb; + const size_t ldc = WorkBlock->ldc; + + const double* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); + const double* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); + double* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; + + MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, + WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); } -inline -bool -MlasDgemmTryMultithread( +void +MlasDgemmSchedule( + MLAS_DGEMM_WORK_BLOCK* WorkBlock, + MLAS_THREADPOOL* ThreadPool + ) +/*++ + +Routine Description: + + This routine schedules the double precision matrix/matrix multiply + operation (DGEMM) across one or more threads. + +Arguments: + + WorkBlock - Supplies the structure containing the GEMM parameters. + + ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + +Return Value: + + None. + +--*/ +{ + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; + const size_t K = WorkBlock->K; + + // + // Compute the number of target threads given the complexity of the DGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + int32_t TargetThreadCount; + + if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) { + TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT; + } + + int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + + if (N > M) { + + const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_DGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(TargetThreadCount) > BlockedN) { + TargetThreadCount = int32_t(BlockedN); + } + + WorkBlock->ThreadCountM = 1; + WorkBlock->ThreadCountN = TargetThreadCount; + + } else { + + if (size_t(TargetThreadCount) > M) { + TargetThreadCount = int32_t(M); + } + + WorkBlock->ThreadCountM = TargetThreadCount; + WorkBlock->ThreadCountN = 1; + } + + MlasExecuteThreaded(MlasDgemmThreaded, WorkBlock, TargetThreadCount, ThreadPool); +} + +void +MLASCALL +MlasGemm( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t M, @@ -776,8 +932,8 @@ MlasDgemmTryMultithread( Routine Description: - This routine attempts to launch a single precision matrix/matrix multiply - operation (DGEMM) across multiple threads. + This routine implements the double precision matrix/matrix multiply + operation (DGEMM). Arguments: @@ -813,190 +969,37 @@ Arguments: Return Value: - Returns true if the operation was completed across multiple threads, else - false if the operation should fall back to a single thread. + None. --*/ { MLAS_DGEMM_WORK_BLOCK WorkBlock; - int32_t TargetThreadCount; // - // Compute the number of target threads given the complexity of the DGEMM - // operation. Small requests should run using the single threaded path. + // Capture the GEMM parameters to the work block. // - double Complexity = double(M) * double(N) * double(K); - - if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) { - TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT; - } - - int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - if (TargetThreadCount == 1) { - return false; - } - - // - // Initialize the common fields of the work block. - // + memset(&WorkBlock, 0, sizeof(MLAS_DGEMM_WORK_BLOCK)); WorkBlock.TransA = TransA; WorkBlock.TransB = TransB; + WorkBlock.M = M; + WorkBlock.N = N; WorkBlock.K = K; + WorkBlock.A = A; WorkBlock.lda = lda; + WorkBlock.B = B; WorkBlock.ldb = ldb; + WorkBlock.C = C; WorkBlock.ldc = ldc; WorkBlock.alpha = alpha; WorkBlock.beta = beta; // - // Segment the operation across multiple threads. + // Schedule the operation across a set of worker threads. // - int32_t Index = 0; - - if (N > M) { - - size_t StrideN = N / TargetThreadCount; - - if ((StrideN * TargetThreadCount) != N) { - StrideN++; - } - - StrideN = - (StrideN + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1); - - size_t pldb = (TransB == CblasNoTrans) ? 1 : ldb; - - for (size_t CountN, n = 0; n < N; n += CountN) { - - CountN = StrideN; - - if (CountN > (N - n)) { - CountN = N - n; - } - - WorkBlock.Segments[Index].M = M; - WorkBlock.Segments[Index].N = CountN; - WorkBlock.Segments[Index].A = A; - WorkBlock.Segments[Index].B = B + n * pldb; - WorkBlock.Segments[Index].C = C + n; - - Index++; - } - - } else { - - size_t StrideM = M / TargetThreadCount; - - if ((StrideM * TargetThreadCount) != M) { - StrideM++; - } - - size_t plda = (TransA == CblasNoTrans) ? lda : 1; - - for (size_t CountM, m = 0; m < M; m += CountM) { - - CountM = StrideM; - - if (CountM > (M - m)) { - CountM = M - m; - } - - WorkBlock.Segments[Index].M = CountM; - WorkBlock.Segments[Index].N = N; - WorkBlock.Segments[Index].A = A + m * plda; - WorkBlock.Segments[Index].B = B; - WorkBlock.Segments[Index].C = C + m * ldc; - - Index++; - } - } - - MlasExecuteThreaded(MlasDgemmOperationThreaded, &WorkBlock, Index, ThreadPool); - - return true; -} - -void -MLASCALL -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - // - // Try to run the operation across multiple threads or fall back to a - // single thread based on the GEMM parameters and system configuration. - // - - if (!MlasDgemmTryMultithread(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, ThreadPool)) { - MlasDgemmOperation(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); - } + MlasDgemmSchedule(&WorkBlock, ThreadPool); } #endif diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 7489d89095..31052a27fc 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -41,13 +41,13 @@ struct MLAS_SGEMM_WORK_BLOCK { size_t K; const float* A; size_t lda; - const float* B; + const void* B; size_t ldb; - const void* PackedB; float* C; size_t ldc; float alpha; float beta; + bool BIsPacked; }; void @@ -85,7 +85,7 @@ Return Value: { MLAS_FLOAT32X4 BetaBroadcast = MlasBroadcastFloat32x4(beta); - do { + while (CountM-- > 0) { float* c = C; size_t n = CountN; @@ -107,9 +107,7 @@ Return Value: } C += ldc; - CountM--; - - } while (CountM > 0); + } } void @@ -807,7 +805,7 @@ Return Value: --*/ { - do { + while (CountM > 0) { size_t RowsHandled; @@ -826,8 +824,7 @@ Return Value: C += ldc * RowsHandled; A += lda * RowsHandled; CountM -= RowsHandled; - - } while (CountM > 0); + } return C; } @@ -893,6 +890,16 @@ Return Value: float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_STRIDEK]; MLAS_DECLSPEC_ALIGN(float PanelB[MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK], 16 * sizeof(float)); + // + // Handle the special case of K equals zero. Apply the beta multiplier to + // the output matrix and exit. + // + + if (K == 0) { + MlasSgemmMultiplyBeta(C, M, N, ldc, beta); + return; + } + // // Handle the special case of a small M. The data from matrix B is not // referenced multiple times, so using a local packed buffer is a wasted @@ -1035,7 +1042,7 @@ Return Value: const float* a = A + k * lda; size_t RowsRemaining = M; - do { + while (RowsRemaining > 0) { // // Transpose elements from matrix A into a local buffer. @@ -1053,8 +1060,7 @@ Return Value: // c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); - - } while (RowsRemaining > 0); + } } ZeroMode = false; @@ -1171,7 +1177,7 @@ Return Value: const float* a = A + k * lda; size_t RowsRemaining = M; - do { + while (RowsRemaining > 0) { // // Transpose elements from matrix A into a local buffer. @@ -1189,8 +1195,7 @@ Return Value: // c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); - - } while (RowsRemaining > 0); + } } ZeroMode = false; @@ -1271,22 +1276,22 @@ Return Value: const float* A = WorkBlock->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); float* C = WorkBlock->C + RangeStartM * ldc + RangeStartN; - if (WorkBlock->B != nullptr) { + if (WorkBlock->BIsPacked) { + + MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, + WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->B, + BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc); + + } else { CBLAS_TRANSPOSE TransB = WorkBlock->TransB; const size_t ldb = WorkBlock->ldb; - const float* B = WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); + const float* B = (const float*)WorkBlock->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, WorkBlock->K, WorkBlock->alpha, A, lda, B, ldb, WorkBlock->beta, C, ldc); - - } else { - - MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, - WorkBlock->K, WorkBlock->alpha, A, lda, WorkBlock->PackedB, - BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, WorkBlock->beta, C, ldc); } } @@ -1535,11 +1540,12 @@ Return Value: WorkBlock.K = K; WorkBlock.A = A; WorkBlock.lda = lda; - WorkBlock.PackedB = PackedB; + WorkBlock.B = PackedB; WorkBlock.C = C; WorkBlock.ldc = ldc; WorkBlock.alpha = alpha; WorkBlock.beta = beta; + WorkBlock.BIsPacked = true; // // Schedule the operation across a set of worker threads. diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index 451655b214..9b42a29f1b 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -202,10 +202,10 @@ public: }; template -class MlasFgemmTestBase; +class FgemmPackedContext; template -class MlasFgemmTestBase : public MlasTestBase +class FgemmPackedContext { public: void @@ -230,7 +230,7 @@ public: }; template -class MlasFgemmTestBase : public MlasTestBase +class FgemmPackedContext { public: void @@ -261,7 +261,7 @@ private: }; template -class MlasFgemmTest : public MlasFgemmTestBase +class MlasFgemmTest : public MlasTestBase { private: void @@ -273,6 +273,14 @@ private: float beta ) { + // + // Skip the test if the B buffer cannot be packed. + // + + if (Packed && (N == 0 || K == 0)) { + return; + } + const T* A = BufferA.GetBuffer(K * M); const T* B = BufferB.GetBuffer(N * K); T* C = BufferC.GetBuffer(N * M); @@ -305,7 +313,7 @@ private: std::fill_n(C, M * N, -0.5f); std::fill_n(CReference, M * N, -0.5f); - this->TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + PackedContext.TestGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); ReferenceGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc); for (size_t f = 0; f < M * N; f++) { @@ -430,6 +438,7 @@ private: MatrixGuardBuffer BufferB; MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; + FgemmPackedContext PackedContext; public: void @@ -437,7 +446,7 @@ public: void ) override { - for (size_t b = 1; b < 16; b++) { + for (size_t b = 0; b < 16; b++) { Test(b, b, b, 1.0f, 0.0f); } for (size_t b = 16; b <= 256; b <<= 1) { @@ -504,9 +513,9 @@ public: } } - for (size_t M = 1; M < 160; M++) { - for (size_t N = 1; N < 160; N++) { - for (size_t K = 1; K < 160; K++) { + for (size_t M = 0; M < 160; M++) { + for (size_t N = 0; N < 160; N++) { + for (size_t K = 0; K < 160; K++) { Test(M, N, K, 1.0f, 0.0f); } } @@ -515,7 +524,7 @@ public: for (size_t M = 160; M < 320; M += 24) { for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 1; K < 16; K++) { + for (size_t K = 0; K < 16; K++) { Test(M, N, K, 1.0f, 0.0f); } for (size_t K = 16; K < 160; K += 32) {