/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

    sgemm.cpp

Abstract:

    This module implements the single precision matrix/matrix multiply
    operation (SGEMM).

--*/

#include "mlasi.h"

//
// Define the number of rows from matrix A to transpose to a local buffer.
//
// N.B. AVX processes a maximum of 4 rows, FMA3 processes a maximum of 6
// rows, and AVX512F processes a maximum of 12 rows.
//

#define MLAS_SGEMM_TRANSA_ROWS              12

//
// Define the parameters to execute segments of a SGEMM operation on worker
// threads.
//
// The leading fields are shared by every segment; each SEGMENT entry carries
// the sub-matrix extents and base addresses for one worker invocation.
//

struct MLAS_SGEMM_WORK_BLOCK {
    CBLAS_TRANSPOSE TransA;
    CBLAS_TRANSPOSE TransB;
    size_t K;
    size_t lda;
    size_t ldb;
    size_t ldc;
    float alpha;
    float beta;
    struct SEGMENT {
        size_t M;
        size_t N;
        const float* A;
        const float* B;
        float* C;
    } Segments[MLAS_MAXIMUM_THREAD_COUNT];
};

void
MlasSgemmMultiplyBeta(
    float* C,
    size_t CountM,
    size_t CountN,
    size_t ldc,
    float beta
    )
/*++

Routine Description:

    This routine multiplies all elements of the output matrix by the beta
    scalar value.

Arguments:

    C - Supplies the address of matrix C.

    CountM - Supplies the number of rows from matrix C. Must be nonzero: the
        row loop is a do/while that always processes at least one row.

    CountN - Supplies the number of columns from matrix C.

    ldc - Supplies the first dimension of matrix C.

    beta - Supplies the scalar beta multiplier (see SGEMM definition).

Return Value:

    None.

--*/
{
    MLAS_FLOAT32X4 BetaBroadcast = MlasBroadcastFloat32x4(beta);

    do {

        float* c = C;
        size_t n = CountN;

        //
        // Scale the row four elements at a time.
        //

        while (n >= 4) {
            MlasStoreFloat32x4(c, MlasMultiplyFloat32x4(MlasLoadFloat32x4(c), BetaBroadcast));
            c += 4;
            n -= 4;
        }

        //
        // Scale the remaining elements of the row one at a time.
        //

        while (n > 0) {
#if defined(MLAS_SSE2_INTRINSICS)
            _mm_store_ss(c, _mm_mul_ss(_mm_load_ss(c), BetaBroadcast));
#else
            *c = *c * beta;
#endif
            c += 1;
            n -= 1;
        }

        C += ldc;
        CountM--;

    } while (CountM > 0);
}

void
MlasSgemmTransposeA(
    float* D,
    const float* A,
    size_t lda,
    size_t CountY,
    size_t CountX
    )
/*++

Routine Description:

    This routine transposes elements from the source matrix to the destination
    buffer.

    The destination buffer is written with a row stride of CountX elements.

Arguments:

    D - Supplies the address of the destination buffer.

    A - Supplies the address of the source matrix.

    lda - Supplies the number of elements per row of the source matrix.

    CountY - Supplies the number of columns of the source matrix to transpose.
        Must be nonzero when CountX is nonzero: the column loops are do/while
        loops that always process at least one column.

    CountX - Supplies the number of rows of the source matrix to transpose.

Return Value:

    None.

--*/
{
    size_t ldd = CountX;

    //
    // Transpose elements from matrix A into the destination buffer 4 columns
    // at a time.
    //

    while (CountX >= 4) {

        float* d = D;
        const float* a = A;
        size_t y = CountY;

        do {

            float t0 = a[0];
            float t1 = a[lda];
            float t2 = a[lda * 2];
            float t3 = a[lda * 3];

            d[0] = t0;
            d[1] = t1;
            d[2] = t2;
            d[3] = t3;

            d += ldd;
            a += 1;
            y--;

        } while (y > 0);

        D += 4;
        A += lda * 4;
        CountX -= 4;
    }

    //
    // Transpose elements from matrix A into the destination buffer for the
    // remaining columns.
    //

    if (CountX >= 2) {

        float* d = D;
        const float* a = A;
        size_t y = CountY;

        do {

            float t0 = a[0];
            float t1 = a[lda];

            d[0] = t0;
            d[1] = t1;

            d += ldd;
            a += 1;
            y--;

        } while (y > 0);

        D += 2;
        A += lda * 2;
        CountX -= 2;
    }

    if (CountX >= 1) {

        float* d = D;
        const float* a = A;
        size_t y = CountY;

        do {

            d[0] = a[0];

            d += ldd;
            a += 1;
            y--;

        } while (y > 0);
    }
}

void
MlasSgemmCopyPackB(
    float* D,
    const float* B,
    size_t ldb,
    size_t CountX,
    size_t CountY
    )
/*++

Routine Description:

    This routine copies elements from the source matrix to the destination
    packed buffer.

    Columns of 16 elements from the source matrix are unrolled to be physically
    contiguous for better locality inside the SGEMM kernels. Any remaining
    columns less than 16 elements wide are zero-padded.

Arguments:

    D - Supplies the address of the destination packed buffer. The buffer is
        written with aligned vector stores.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

    CountX - Supplies the number of columns of the source matrix to copy.

    CountY - Supplies the number of rows of the source matrix to copy. Must be
        nonzero when CountX is nonzero: the row loops are do/while loops that
        always process at least one row.

Return Value:

    None.

--*/
{
    //
    // Copy data from matrix B into the destination buffer 16 columns at a
    // time.
    //

    while (CountX >= 16) {

        const float* b = B;
        size_t y = CountY;

        do {

#if defined(MLAS_NEON_INTRINSICS)
            vst4q_f32(D, vld4q_f32(b));
#else
            MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
            MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]);
            MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&b[8]);
            MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&b[12]);

            MlasStoreAlignedFloat32x4(&D[0], t0);
            MlasStoreAlignedFloat32x4(&D[4], t1);
            MlasStoreAlignedFloat32x4(&D[8], t2);
            MlasStoreAlignedFloat32x4(&D[12], t3);
#endif

            D += 16;
            b += ldb;
            y--;

        } while (y > 0);

        B += 16;
        CountX -= 16;
    }

    //
    // Special case the handling of the remaining columns less than 16 elements
    // wide.
    //

    if (CountX > 0) {

        MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();

#if defined(MLAS_NEON_INTRINSICS)
        float32x4x4_t ZeroFloat32x4x4 = { ZeroFloat32x4, ZeroFloat32x4, ZeroFloat32x4, ZeroFloat32x4 };
#endif

        size_t y = CountY;

        do {

            float* d = D;
            const float* b = B;

            //
            // Clear the full 16 element destination row first so that the
            // columns not covered by the copies below are zero-padded.
            //

#if defined(MLAS_NEON_INTRINSICS)
            vst4q_f32(d, ZeroFloat32x4x4);
#else
            MlasStoreAlignedFloat32x4(d, ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(d + 8, ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(d + 12, ZeroFloat32x4);
#endif

            //
            // Copy the remaining 1..15 columns by decomposing the count into
            // its 8, 4, 2 and 1 element pieces.
            //

            if ((CountX & 8) != 0) {

                MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(b);
                MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(b + 4);

                MlasStoreAlignedFloat32x4(d, t0);
                MlasStoreAlignedFloat32x4(d + 4, t1);

                d += 8;
                b += 8;
            }

            if ((CountX & 4) != 0) {

                MlasStoreAlignedFloat32x4(d, MlasLoadFloat32x4(b));

                d += 4;
                b += 4;
            }

            if ((CountX & 2) != 0) {

                float t0 = b[0];
                float t1 = b[1];

                d[0] = t0;
                d[1] = t1;

                d += 2;
                b += 2;
            }

            if ((CountX & 1) != 0) {
                d[0] = b[0];
            }

            D += 16;
            B += ldb;
            y--;

        } while (y > 0);
    }
}

template<unsigned N>
inline
void
MlasSgemmTransposePackBNx4(
    float* D,
    const float* B,
    size_t ldb
    )
/*++

Routine Description:

    This routine transposes elements from the source matrix to the destination
    packed buffer.

    4 columns of N rows from the source matrix are transposed to N columns of 4
    rows in the destination packed buffer.

    The template parameter N is consumed in groups of four source rows (the
    loop runs N / 4 times), and the destination rows are 16 elements apart.

Arguments:

    D - Supplies the address of the destination packed buffer.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

Return Value:

    None.

--*/
{
    for (unsigned n = 0; n < N / 4; n++) {

        MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&B[ldb * 0]);
        MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&B[ldb * 1]);
        MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&B[ldb * 2]);
        MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&B[ldb * 3]);

        //
        // Transpose the 4x4 tile held in t0..t3 using two interleave stages.
        //

#if defined(MLAS_NEON_INTRINSICS)
        float32x4x2_t z0 = vzipq_f32(t0, t2);
        float32x4x2_t z1 = vzipq_f32(t1, t3);
        float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]);
        float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]);
        t0 = o0.val[0];
        t1 = o0.val[1];
        t2 = o1.val[0];
        t3 = o1.val[1];
#else
        MLAS_FLOAT32X4 z0 = MlasInterleaveLowFloat32x4(t0, t2);
        MLAS_FLOAT32X4 z1 = MlasInterleaveHighFloat32x4(t0, t2);
        MLAS_FLOAT32X4 z2 = MlasInterleaveLowFloat32x4(t1, t3);
        MLAS_FLOAT32X4 z3 = MlasInterleaveHighFloat32x4(t1, t3);
        t0 = MlasInterleaveLowFloat32x4(z0, z2);
        t1 = MlasInterleaveHighFloat32x4(z0, z2);
        t2 = MlasInterleaveLowFloat32x4(z1, z3);
        t3 = MlasInterleaveHighFloat32x4(z1, z3);
#endif

        MlasStoreAlignedFloat32x4(&D[0], t0);
        MlasStoreAlignedFloat32x4(&D[16], t1);
        MlasStoreAlignedFloat32x4(&D[32], t2);
        MlasStoreAlignedFloat32x4(&D[48], t3);

        D += 4;
        B += ldb * 4;
    }
}

void
MlasSgemmTransposePackB(
    float* D,
    const float* B,
    size_t ldb,
    size_t CountY,
    size_t CountX
    )
/*++

Routine Description:

    This routine transposes elements from the source matrix to the destination
    packed buffer.

    Columns of 16 elements from the source matrix are unrolled to be physically
    contiguous for better locality inside the SGEMM kernels. Any remaining
    columns less than 16 elements wide are zero-padded.

Arguments:

    D - Supplies the address of the destination packed buffer. The buffer is
        written with aligned vector stores.

    B - Supplies the address of the source matrix.

    ldb - Supplies the number of elements per row of the source matrix.

    CountY - Supplies the number of rows of the source matrix to transpose.

    CountX - Supplies the number of columns of the source matrix to transpose.

Return Value:

    None.

--*/
{
    //
    // Transpose elements from matrix B into the packed buffer 16 rows at a
    // time.
    //

    while (CountY >= 16) {

        const float* b = B;
        size_t x = CountX;

#if defined(MLAS_TARGET_AMD64)

        PMLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE SgemmTransposePackB16x4Routine =
            MlasPlatform.TransposePackB16x4Routine;

        while (x >= 4) {

            SgemmTransposePackB16x4Routine(&D[0], &b[0], ldb);

            D += 16 * 4;
            b += 4;
            x -= 4;
        }

#else

        while (x >= 4) {

            MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb);

            D += 16 * 4;
            b += 4;
            x -= 4;
        }

#endif

        //
        // Transpose the remaining columns one at a time.
        //

        while (x > 0) {

            float t0 = b[0];
            float t1 = b[ldb];
            float t2 = b[ldb * 2];
            float t3 = b[ldb * 3];
            float t4 = b[ldb * 4];
            float t5 = b[ldb * 5];
            float t6 = b[ldb * 6];
            float t7 = b[ldb * 7];
            float t8 = b[ldb * 8];
            float t9 = b[ldb * 9];
            float t10 = b[ldb * 10];
            float t11 = b[ldb * 11];
            float t12 = b[ldb * 12];
            float t13 = b[ldb * 13];
            float t14 = b[ldb * 14];
            float t15 = b[ldb * 15];

            D[0] = t0;
            D[1] = t1;
            D[2] = t2;
            D[3] = t3;
            D[4] = t4;
            D[5] = t5;
            D[6] = t6;
            D[7] = t7;
            D[8] = t8;
            D[9] = t9;
            D[10] = t10;
            D[11] = t11;
            D[12] = t12;
            D[13] = t13;
            D[14] = t14;
            D[15] = t15;

            D += 16;
            b += 1;
            x--;
        }

        B += ldb * 16;
        CountY -= 16;
    }

    //
    // Special case the handling of the less than 16 remaining rows.
    //

    if (CountY > 0) {

        MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();

        size_t x = CountX;

        //
        // Transpose 4 columns at a time.
        //
        // The remaining row count is decomposed into its 8, 4, 2 and 1 row
        // pieces. Whenever a piece is absent, the corresponding slots of the
        // four 16 element destination rows are cleared instead.
        //

        while (x >= 4) {

            float* d = D;
            const float* b = B;

            if ((CountY & 8) != 0) {

                MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb);

                d += 8;
                b += ldb * 8;

            } else {

                MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[24], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[28], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[40], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[44], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[56], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[60], ZeroFloat32x4);
            }

            if ((CountY & 4) != 0) {

                MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb);

                d += 4;
                b += ldb * 4;

            } else {

                MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[20], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[36], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[52], ZeroFloat32x4);
            }

            //
            // Clear the four element group that the 2 and 1 row pieces below
            // only partially overwrite.
            //

            MlasStoreAlignedFloat32x4(&d[0], ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(&d[16], ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(&d[32], ZeroFloat32x4);
            MlasStoreAlignedFloat32x4(&d[48], ZeroFloat32x4);

            if ((CountY & 2) != 0) {

                MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
                MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]);

#if defined(MLAS_SSE2_INTRINSICS)
                __m128 v0 = _mm_unpacklo_ps(t0, t1);
                __m128 v1 = _mm_unpackhi_ps(t0, t1);
                _mm_storel_pi((__m64*)&d[0], v0);
                _mm_storeh_pi((__m64*)&d[16], v0);
                _mm_storel_pi((__m64*)&d[32], v1);
                _mm_storeh_pi((__m64*)&d[48], v1);
#else
                MlasStoreLaneFloat32x4<0>(&d[0], t0);
                MlasStoreLaneFloat32x4<0>(&d[1], t1);
                MlasStoreLaneFloat32x4<1>(&d[16], t0);
                MlasStoreLaneFloat32x4<1>(&d[17], t1);
                MlasStoreLaneFloat32x4<2>(&d[32], t0);
                MlasStoreLaneFloat32x4<2>(&d[33], t1);
                MlasStoreLaneFloat32x4<3>(&d[48], t0);
                MlasStoreLaneFloat32x4<3>(&d[49], t1);
#endif

                d += 2;
                b += ldb * 2;
            }

            if ((CountY & 1) != 0) {

#if defined(MLAS_NEON_INTRINSICS)
                MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);

                MlasStoreLaneFloat32x4<0>(&d[0], t0);
                MlasStoreLaneFloat32x4<1>(&d[16], t0);
                MlasStoreLaneFloat32x4<2>(&d[32], t0);
                MlasStoreLaneFloat32x4<3>(&d[48], t0);
#else
                d[0] = b[0];
                d[16] = b[1];
                d[32] = b[2];
                d[48] = b[3];
#endif
            }

            D += 16 * 4;
            B += 4;
            x -= 4;
        }

        //
        // Transpose the remaining columns.
        //

        while (x > 0) {

            float* d = D;
            const float* b = B;

            if ((CountY & 8) != 0) {

                float t0 = b[0];
                float t1 = b[ldb];
                float t2 = b[ldb * 2];
                float t3 = b[ldb * 3];
                float t4 = b[ldb * 4];
                float t5 = b[ldb * 5];
                float t6 = b[ldb * 6];
                float t7 = b[ldb * 7];

                d[0] = t0;
                d[1] = t1;
                d[2] = t2;
                d[3] = t3;
                d[4] = t4;
                d[5] = t5;
                d[6] = t6;
                d[7] = t7;

                d += 8;
                b += ldb * 8;

            } else {

                MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4);
                MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4);
            }

            if ((CountY & 4) != 0) {

                float t0 = b[0];
                float t1 = b[ldb];
                float t2 = b[ldb * 2];
                float t3 = b[ldb * 3];

                d[0] = t0;
                d[1] = t1;
                d[2] = t2;
                d[3] = t3;

                d += 4;
                b += ldb * 4;

            } else {

                MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4);
            }

            MlasStoreAlignedFloat32x4(d, ZeroFloat32x4);

            if ((CountY & 2) != 0) {

                float t0 = b[0];
                float t1 = b[ldb];

                d[0] = t0;
                d[1] = t1;

                d += 2;
                b += ldb * 2;
            }

            if ((CountY & 1) != 0) {
                d[0] = b[0];
            }

            D += 16;
            B += 1;
            x--;
        }
    }
}

MLAS_FORCEINLINE
float*
MlasSgemmKernelLoop(
    const float* A,
    const float* B,
    float* C,
    size_t CountK,
    size_t CountM,
    size_t CountN,
    size_t lda,
    size_t ldc,
    float alpha,
    bool ZeroMode
    )
/*++

Routine Description:

    This routine steps through the rows of the input and output matrices calling
    the kernel until all rows have been processed.

Arguments:

    A - Supplies the address of matrix A.

    B - Supplies the address of matrix B. The matrix data has been packed using
        MlasSgemmCopyPackB or MlasSgemmTransposePackB.

    C - Supplies the address of matrix C.

    CountK - Supplies the number of columns from matrix A and the number of rows
        from matrix B to iterate over.

    CountM - Supplies the number of rows from matrix A and matrix C to iterate
        over. Must be nonzero: the kernel is always invoked at least once.

    CountN - Supplies the number of columns from matrix B and matrix C to
        iterate over.

    lda - Supplies the first dimension of matrix A.

    ldc - Supplies the first dimension of matrix C.

    alpha - Supplies the scalar alpha multiplier (see SGEMM definition).

    ZeroMode - Supplies true if the output matrix must be zero initialized,
        else false if the output matrix is accumulated into.

Return Value:

    Returns the next address of matrix C.

--*/
{
    do {

        size_t RowsHandled;

#if defined(MLAS_TARGET_AMD64_IX86)
        RowsHandled = MlasPlatform.GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#elif defined(MLAS_TARGET_POWER)
        RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#else
        if (ZeroMode) {
            RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
        } else {
            RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
        }
#endif

        //
        // Advance past the rows that the kernel reported as handled.
        //

        C += ldc * RowsHandled;
        A += lda * RowsHandled;
        CountM -= RowsHandled;

    } while (CountM > 0);

    return C;
}

void
MlasSgemmOperation(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    float alpha,
    const float* A,
    size_t lda,
    const float* B,
    size_t ldb,
    float beta,
    float* C,
    size_t ldc
    )
/*++

Routine Description:

    This routine implements the single precision matrix/matrix multiply
    operation (SGEMM).

    Degenerate shapes are handled before any packing is done: an empty output
    matrix (M or N is zero) is a no-op, and an empty K dimension reduces the
    operation to C := beta * C.

Arguments:

    TransA - Supplies the transpose operation for matrix A.

    TransB - Supplies the transpose operation for matrix B.

    M - Supplies the number of rows of matrix A and matrix C.

    N - Supplies the number of columns of matrix B and matrix C.

    K - Supplies the number of columns of matrix A and the number of rows of
        matrix B.

    alpha - Supplies the scalar alpha multiplier (see SGEMM definition).

    A - Supplies the address of matrix A.

    lda - Supplies the first dimension of matrix A.

    B - Supplies the address of matrix B.

    ldb - Supplies the first dimension of matrix B.

    beta - Supplies the scalar beta multiplier (see SGEMM definition).

    C - Supplies the address of matrix C.

    ldc - Supplies the first dimension of matrix C.

Return Value:

    None.

--*/
{
    //
    // An empty output matrix requires no work. This must be checked up front
    // because the do/while loops in the helper routines used below always
    // execute at least once and would otherwise wrap their unsigned row
    // counters.
    //

    if (M == 0 || N == 0) {
        return;
    }

    //
    // With an empty K dimension the product of A and B contributes nothing, so
    // the operation reduces to C := beta * C. The K loop below would never
    // execute in this case, which would leave C unmodified for a beta of zero
    // instead of clearing it.
    //

    if (K == 0) {

        if (beta == 0.0f) {

            //
            // Store zeroes explicitly rather than multiplying, as the existing
            // contents of C may not be finite.
            //

            for (size_t m = 0; m < M; m++) {

                float* c = C + m * ldc;

                for (size_t n = 0; n < N; n++) {
                    c[n] = 0.0f;
                }
            }

        } else if (beta != 1.0f) {

            MlasSgemmMultiplyBeta(C, M, N, ldc, beta);
        }

        return;
    }

    float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_STRIDEK];
    MLAS_DECLSPEC_ALIGN(float PanelB[MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK], 16 * sizeof(float));

    //
    // Handle the special case of a small M. The data from matrix B is not
    // referenced multiple times, so using a local packed buffer is a wasted
    // memory copy.
    //

    if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) {

#if defined(MLAS_TARGET_AMD64)

        PMLAS_SGEMM_KERNEL_M1_ROUTINE SgemmKernelM1Routine;

        if (TransB == CblasNoTrans) {
            SgemmKernelM1Routine = MlasPlatform.KernelM1Routine;
        } else {
            SgemmKernelM1Routine = MlasPlatform.KernelM1TransposeBRoutine;
        }

        if (SgemmKernelM1Routine != nullptr) {
            SgemmKernelM1Routine(A, B, C, K, N, ldb, beta);
            return;
        }

#endif

    }

    //
    // Handle the case when both B and C are column-vectors that are contiguous in memory.
    // Because transposition of such vectors doesn't change their layout, and
    // Transpose(A*B) = Transpose(B) * Transpose(A), we can apply the same 'small-M'
    // optimization as above, with A and B flipped.
    //

    if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) {

#if defined(MLAS_TARGET_AMD64)

        PMLAS_SGEMM_KERNEL_M1_ROUTINE SgemmKernelM1Routine;

        if (TransA == CblasNoTrans) {
            SgemmKernelM1Routine = MlasPlatform.KernelM1TransposeBRoutine;
        } else {
            SgemmKernelM1Routine = MlasPlatform.KernelM1Routine;
        }

        if (SgemmKernelM1Routine != nullptr) {
            SgemmKernelM1Routine(B, A, C, K, M, lda, beta);
            return;
        }

#endif

    }

    //
    // Compute the strides to step through slices of the input matrices.
    //
    // Expand the N stride if K is small or expand the K stride if N is small
    // for better utilization of the B panel. Avoid changing the K stride if
    // the A panel needs to be used for transposing.
    //

    uint32_t StrideN = MLAS_SGEMM_STRIDEN;
    uint32_t StrideK = MLAS_SGEMM_STRIDEK;

    if (N >= K) {

        while (StrideK / 2 >= K) {
            StrideN *= 2;
            StrideK /= 2;
        }

    } else if (TransA == CblasNoTrans) {

        while (StrideN > 16 && StrideN / 2 >= N) {
            StrideK *= 2;
            StrideN /= 2;
        }
    }

    //
    // Step through each slice of matrix B along the N dimension.
    //

    size_t CountN;
    size_t CountK;

    for (size_t n = 0; n < N; n += CountN) {

        CountN = StrideN;

        if (CountN > (N - n)) {
            CountN = N - n;
        }

        //
        // Multiply the output matrix by beta as needed.
        //

        if (beta != 0.0f && beta != 1.0f) {
            MlasSgemmMultiplyBeta(C + n, M, CountN, ldc, beta);
        }

        //
        // Step through each slice of matrix B along the K dimension.
        //
        // Only the first K slice may run in zero mode; every later slice must
        // accumulate into the partial result.
        //

        bool ZeroMode = (beta == 0.0f);

        for (size_t k = 0; k < K; k += CountK) {

            CountK = StrideK;

            if (CountK > (K - k)) {
                CountK = K - k;
            }

            //
            // Copy or transpose a panel of matrix B to a local packed buffer.
            //

            if (TransB == CblasNoTrans) {
                MlasSgemmCopyPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK);
            } else {
                MlasSgemmTransposePackB(PanelB, B + k + n * ldb, ldb, CountN, CountK);
            }

            //
            // Step through each slice of matrix A along the M dimension.
            //

            float* c = C + n;

            if (TransA == CblasNoTrans) {

                MlasSgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode);

            } else {

                const float* a = A + k * lda;
                size_t RowsRemaining = M;

                do {

                    //
                    // Transpose elements from matrix A into a local buffer.
                    //

                    size_t RowsTransposed = RowsRemaining;

                    if (RowsTransposed > MLAS_SGEMM_TRANSA_ROWS) {
                        RowsTransposed = MLAS_SGEMM_TRANSA_ROWS;
                    }

                    MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK);

                    RowsRemaining -= RowsTransposed;
                    a += RowsTransposed;

                    //
                    // Step through the rows of the local buffer.
                    //

                    c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode);

                } while (RowsRemaining > 0);
            }

            ZeroMode = false;
        }
    }
}

void
MlasSgemmOperationThreaded(
    void* Context,
    int32_t Index
    )
/*++

Routine Description:

    This routine is invoked from a worker thread to execute a segment of a
    SGEMM operation.

Arguments:

    Context - Supplies the pointer to the context for the threaded operation.
        This is interpreted as a pointer to a MLAS_SGEMM_WORK_BLOCK.

    Index - Supplies the current index of the threaded operation. This selects
        the segment of the work block to execute.

Return Value:

    None.

--*/
{
    MLAS_SGEMM_WORK_BLOCK* WorkBlock = static_cast<MLAS_SGEMM_WORK_BLOCK*>(Context);

    MLAS_SGEMM_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index];

    MlasSgemmOperation(WorkBlock->TransA, WorkBlock->TransB, Segment->M,
        Segment->N, WorkBlock->K, WorkBlock->alpha, Segment->A, WorkBlock->lda,
        Segment->B, WorkBlock->ldb, WorkBlock->beta, Segment->C,
        WorkBlock->ldc);
}

inline
bool
MlasSgemmTryMultithread(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    float alpha,
    const float* A,
    size_t lda,
    const float* B,
    size_t ldb,
    float beta,
    float* C,
    size_t ldc,
    MLAS_THREADPOOL* ThreadPool
    )
/*++

Routine Description:

    This routine attempts to launch a single precision matrix/matrix multiply
    operation (SGEMM) across multiple threads.

Arguments:

    TransA - Supplies the transpose operation for matrix A.

    TransB - Supplies the transpose operation for matrix B.

    M - Supplies the number of rows of matrix A and matrix C.

    N - Supplies the number of columns of matrix B and matrix C.

    K - Supplies the number of columns of matrix A and the number of rows of
        matrix B.

    alpha - Supplies the scalar alpha multiplier (see SGEMM definition).

    A - Supplies the address of matrix A.

    lda - Supplies the first dimension of matrix A.

    B - Supplies the address of matrix B.

    ldb - Supplies the first dimension of matrix B.

    beta - Supplies the scalar beta multiplier (see SGEMM definition).

    C - Supplies the address of matrix C.

    ldc - Supplies the first dimension of matrix C.

    ThreadPool - Supplies the thread pool object to use, else nullptr if the
        base library threading support should be used.

Return Value:

    Returns true if the operation was completed across multiple threads, else
    false if the operation should fall back to a single thread.

--*/
{
    MLAS_SGEMM_WORK_BLOCK WorkBlock;
    int32_t TargetThreadCount;

    //
    // Compute the number of target threads given the complexity of the SGEMM
    // operation. Small requests should run using the single threaded path.
    //

    double Complexity = double(M) * double(N) * double(K);

    if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
        TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
    } else {
        TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
    }

    int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);

    if (TargetThreadCount >= MaximumThreadCount) {
        TargetThreadCount = MaximumThreadCount;
    }

    //
    // Fall back to the single threaded path when there is at most one thread
    // to use. A non-positive count must also be rejected here, as the count is
    // used as a divisor when computing the per-thread strides below.
    //

    if (TargetThreadCount <= 1) {
        return false;
    }

    //
    // Initialize the common fields of the work block.
    //

    WorkBlock.TransA = TransA;
    WorkBlock.TransB = TransB;
    WorkBlock.K = K;
    WorkBlock.lda = lda;
    WorkBlock.ldb = ldb;
    WorkBlock.ldc = ldc;
    WorkBlock.alpha = alpha;
    WorkBlock.beta = beta;

    //
    // Segment the operation across multiple threads.
    //
    // The larger of the two output dimensions is partitioned. The stride is
    // rounded up, so the number of segments produced never exceeds the target
    // thread count.
    //

    int32_t Index = 0;

    if (N > M) {

        size_t StrideN = N / TargetThreadCount;

        if ((StrideN * TargetThreadCount) != N) {
            StrideN++;
        }

        StrideN =
            (StrideN + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1);

        //
        // Stepping by one column of the logical matrix B advances by one
        // element when B is not transposed, else by one row of storage.
        //

        size_t pldb = (TransB == CblasNoTrans) ? 1 : ldb;

        for (size_t CountN, n = 0; n < N; n += CountN) {

            CountN = StrideN;

            if (CountN > (N - n)) {
                CountN = N - n;
            }

            WorkBlock.Segments[Index].M = M;
            WorkBlock.Segments[Index].N = CountN;
            WorkBlock.Segments[Index].A = A;
            WorkBlock.Segments[Index].B = B + n * pldb;
            WorkBlock.Segments[Index].C = C + n;

            Index++;
        }

    } else {

        size_t StrideM = M / TargetThreadCount;

        if ((StrideM * TargetThreadCount) != M) {
            StrideM++;
        }

        //
        // Stepping by one row of the logical matrix A advances by one row of
        // storage when A is not transposed, else by one element.
        //

        size_t plda = (TransA == CblasNoTrans) ? lda : 1;

        for (size_t CountM, m = 0; m < M; m += CountM) {

            CountM = StrideM;

            if (CountM > (M - m)) {
                CountM = M - m;
            }

            WorkBlock.Segments[Index].M = CountM;
            WorkBlock.Segments[Index].N = N;
            WorkBlock.Segments[Index].A = A + m * plda;
            WorkBlock.Segments[Index].B = B;
            WorkBlock.Segments[Index].C = C + m * ldc;

            Index++;
        }
    }

    MlasExecuteThreaded(MlasSgemmOperationThreaded, &WorkBlock, Index, ThreadPool);

    return true;
}

void
MLASCALL
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    float alpha,
    const float* A,
    size_t lda,
    const float* B,
    size_t ldb,
    float beta,
    float* C,
    size_t ldc,
    MLAS_THREADPOOL* ThreadPool
    )
/*++

Routine Description:

    This routine implements the single precision matrix/matrix multiply
    operation (SGEMM).

Arguments:

    TransA - Supplies the transpose operation for matrix A.

    TransB - Supplies the transpose operation for matrix B.

    M - Supplies the number of rows of matrix A and matrix C.

    N - Supplies the number of columns of matrix B and matrix C.

    K - Supplies the number of columns of matrix A and the number of rows of
        matrix B.

    alpha - Supplies the scalar alpha multiplier (see SGEMM definition).

    A - Supplies the address of matrix A.

    lda - Supplies the first dimension of matrix A.

    B - Supplies the address of matrix B.

    ldb - Supplies the first dimension of matrix B.

    beta - Supplies the scalar beta multiplier (see SGEMM definition).

    C - Supplies the address of matrix C.

    ldc - Supplies the first dimension of matrix C.

    ThreadPool - Supplies the thread pool object to use, else nullptr if the
        base library threading support should be used.

Return Value:

    None.

--*/
{
    //
    // Try to run the operation across multiple threads or fall back to a
    // single thread based on the GEMM parameters and system configuration.
    //

    if (!MlasSgemmTryMultithread(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, ThreadPool)) {
        MlasSgemmOperation(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
    }
}