diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 66b53c6041..14da273af9 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -293,15 +293,15 @@ Return Value: #if defined(MLAS_NEON_INTRINSICS) vst4q_f32(D, vld4q_f32(b)); #else - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(b); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(b + 4); - MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(b + 8); - MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(b + 12); + MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); + MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]); + MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&b[8]); + MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&b[12]); - MlasStoreAlignedFloat32x4(D, t0); - MlasStoreAlignedFloat32x4(D + 4, t1); - MlasStoreAlignedFloat32x4(D + 8, t2); - MlasStoreAlignedFloat32x4(D + 12, t3); + MlasStoreAlignedFloat32x4(&D[0], t0); + MlasStoreAlignedFloat32x4(&D[4], t1); + MlasStoreAlignedFloat32x4(&D[8], t2); + MlasStoreAlignedFloat32x4(&D[12], t3); #endif D += 16; @@ -335,7 +335,7 @@ Return Value: const float* b = B; #if defined(MLAS_NEON_INTRINSICS) - vst4q_f32(D, ZeroFloat32x4x4); + vst4q_f32(d, ZeroFloat32x4x4); #else MlasStoreAlignedFloat32x4(d, ZeroFloat32x4); MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4); @@ -376,7 +376,7 @@ Return Value: } if ((CountX & 1) != 0) { - *d = *b; + d[0] = b[0]; } D += 16; @@ -387,11 +387,10 @@ Return Value: } } -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) - +template inline void -MlasSgemmTransposePackB16x4( +MlasSgemmTransposePackBNx4( float* D, const float* B, size_t ldb @@ -403,7 +402,7 @@ Routine Description: This routine transposes elements from the source matrix to the destination packed buffer. - 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + 4 columns of N rows from the source matrix are transposed to N columns of 4 rows in the destination packed buffer. Arguments: @@ -420,7 +419,7 @@ Return Value: --*/ { - for (unsigned n = 0; n < 4; n++) { + for (unsigned n = 0; n < N / 4; n++) { MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&B[ldb * 0]); MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&B[ldb * 1]); @@ -436,8 +435,20 @@ Return Value: t1 = o0.val[1]; t2 = o1.val[0]; t3 = o1.val[1]; +#elif defined(MLAS_SSE2_INTRINSICS) + // N.B. The MSVC version of _MM_TRANSPOSE4_PS uses shufps which is + // slightly larger than the below sequence, so manually expand the + // matrix transpose. + __m128 z0 = _mm_unpacklo_ps(t0, t1); + __m128 z1 = _mm_unpackhi_ps(t0, t1); + __m128 z2 = _mm_unpacklo_ps(t2, t3); + __m128 z3 = _mm_unpackhi_ps(t2, t3); + t0 = _mm_movelh_ps(z0, z2); + t1 = _mm_movehl_ps(z2, z0); + t2 = _mm_movelh_ps(z1, z3); + t3 = _mm_movehl_ps(z3, z1); #else - _MM_TRANSPOSE4_PS(t0, t1, t2, t3); +#error Unsupported architecture. #endif MlasStoreAlignedFloat32x4(&D[0], t0); @@ -450,8 +461,6 @@ Return Value: } } -#endif - void MlasSgemmTransposePackB( float* D, @@ -508,19 +517,19 @@ Return Value: SgemmTransposePackB16x4Routine(&D[0], &b[0], ldb); - b += 4; D += 16 * 4; + b += 4; x -= 4; } -#elif defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#else while (x >= 4) { - MlasSgemmTransposePackB16x4(&D[0], &b[0], ldb); + MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb); - b += 4; D += 16 * 4; + b += 4; x -= 4; } @@ -571,26 +580,128 @@ Return Value: CountY -= 16; } + // + // Special case the handling of the less than 16 remaining rows. + // + if (CountY > 0) { - // - // Zero pad the remaining entries of the packed buffer. - // + MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - memset(D, 0, CountX * 16 * sizeof(float)); + size_t x = CountX; // - // Transpose elements from matrix B into the packed buffer for the - // remaining rows. + // Transpose 4 columns at a time. // - if (CountY >= 8) { + while (x >= 4) { float* d = D; const float* b = B; - size_t x = CountX; - do { + if ((CountY & 8) != 0) { + + MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb); + + d += 8; + b += ldb * 8; + + } else { + + MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[24], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[28], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[40], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[44], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[56], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[60], ZeroFloat32x4); + } + + if ((CountY & 4) != 0) { + + MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb); + + d += 4; + b += ldb * 4; + + } else { + + MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[20], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[36], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[52], ZeroFloat32x4); + } + + MlasStoreAlignedFloat32x4(&d[0], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[16], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[32], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[48], ZeroFloat32x4); + + if ((CountY & 2) != 0) { + +#if defined(MLAS_NEON_INTRINSICS) + MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); + MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]); + + MlasStoreLaneFloat32x4<0>(&d[0], t0); + MlasStoreLaneFloat32x4<0>(&d[1], t1); + MlasStoreLaneFloat32x4<1>(&d[16], t0); + MlasStoreLaneFloat32x4<1>(&d[17], t1); + MlasStoreLaneFloat32x4<2>(&d[32], t0); + MlasStoreLaneFloat32x4<2>(&d[33], t1); + MlasStoreLaneFloat32x4<3>(&d[48], t0); + MlasStoreLaneFloat32x4<3>(&d[49], t1); +#elif defined(MLAS_SSE2_INTRINSICS) + MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); + MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]); + + __m128 v0 = _mm_unpacklo_ps(t0, t1); + __m128 v1 = _mm_unpackhi_ps(t0, t1); + _mm_storel_pi((__m64*)&d[0], v0); + _mm_storeh_pi((__m64*)&d[16], v0); + _mm_storel_pi((__m64*)&d[32], v1); + _mm_storeh_pi((__m64*)&d[48], v1); +#else +#error Unsupported architecture. +#endif + + d += 2; + b += ldb * 2; + } + + if ((CountY & 1) != 0) { + +#if defined(MLAS_NEON_INTRINSICS) + MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); + + MlasStoreLaneFloat32x4<0>(&d[0], t0); + MlasStoreLaneFloat32x4<1>(&d[16], t0); + MlasStoreLaneFloat32x4<2>(&d[32], t0); + MlasStoreLaneFloat32x4<3>(&d[48], t0); +#else + d[0] = b[0]; + d[16] = b[1]; + d[32] = b[2]; + d[48] = b[3]; +#endif + } + + D += 16 * 4; + B += 4; + x -= 4; + } + + // + // Transpose the remaining columns. + // + + while (x > 0) { + + float* d = D; + const float* b = B; + + if ((CountY & 8) != 0) { float t0 = b[0]; float t1 = b[ldb]; @@ -610,24 +721,16 @@ Return Value: d[6] = t6; d[7] = t7; - d += 16; - b += 1; - x--; + d += 8; + b += ldb * 8; - } while (x > 0); + } else { - D += 8; - B += ldb * 8; - CountY -= 8; - } + MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4); + MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4); + } - if (CountY >= 4) { - - float* d = D; - const float* b = B; - size_t x = CountX; - - do { + if ((CountY & 4) != 0) { float t0 = b[0]; float t1 = b[ldb]; @@ -639,24 +742,17 @@ Return Value: d[2] = t2; d[3] = t3; - d += 16; - b += 1; - x--; + d += 4; + b += ldb * 4; - } while (x > 0); + } else { - D += 4; - B += ldb * 4; - CountY -= 4; - } + MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4); + } - if (CountY >= 2) { + MlasStoreAlignedFloat32x4(d, ZeroFloat32x4); - float* d = D; - const float* b = B; - size_t x = CountX; - - do { + if ((CountY & 2) != 0) { float t0 = b[0]; float t1 = b[ldb]; @@ -664,32 +760,17 @@ Return Value: d[0] = t0; d[1] = t1; - d += 16; - b += 1; - x--; - - } while (x > 0); - - D += 2; - B += ldb * 2; - CountY -= 2; - } - - if (CountY >= 1) { - - float* d = D; - const float* b = B; - size_t x = CountX; - - do { + d += 2; + b += ldb * 2; + } + if ((CountY & 1) != 0) { d[0] = b[0]; + } - d += 16; - b += 1; - x--; - - } while (x > 0); + D += 16; + B += 1; + x--; } } }