diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index b4714b8130..6015dc0f86 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -21,20 +21,28 @@ Abstract: // Quantized integer matrix/matrix dispatch structure. // -typedef void(MLAS_GEMM_U8X8_OPERATION)(const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN); +typedef +void +(MLAS_GEMM_U8X8_OPERATION)( + const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); -typedef void(MLAS_GEMM_U8X8_COPY_PACKB_ROUTINE)(uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned); +typedef +void +(MLAS_GEMM_U8X8_COPY_PACKB_ROUTINE)( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ); struct MLAS_GEMM_U8X8_DISPATCH { MLAS_GEMM_U8X8_OPERATION* Operation; @@ -45,7 +53,9 @@ struct MLAS_GEMM_U8X8_DISPATCH { }; const MLAS_GEMM_U8X8_DISPATCH* -MlasGemmU8X8GetDispatch(bool BIsSigned) +MlasGemmU8X8GetDispatch( + bool BIsSigned + ) { const MLAS_GEMM_U8X8_DISPATCH* GemmU8X8Dispatch; @@ -92,7 +102,12 @@ struct MLAS_GEMM_U8X8_STRIDES { }; void -MlasGemmU8X8ScaleSumBuffer(int32_t* Output, const int32_t* Input, size_t N, int32_t Scale) +MlasGemmU8X8ScaleSumBuffer( + int32_t* Output, + const int32_t* Input, + size_t N, + int32_t Scale + ) { for (size_t n = 0; n < N; n++) { Output[n] = Input[n] * Scale; @@ -101,20 +116,27 @@ MlasGemmU8X8ScaleSumBuffer(int32_t* Output, const int32_t* Input, size_t N, int3 MLAS_FORCEINLINE void -MlasGemmU8X8ScaleSumBuffer(int32_t* SumBuffer, size_t N, int32_t Scale) +MlasGemmU8X8ScaleSumBuffer( + int32_t* SumBuffer, + size_t N, + int32_t Scale + ) { return MlasGemmU8X8ScaleSumBuffer(SumBuffer, SumBuffer, N, Scale); } -template -MLAS_FORCEINLINE bool -MlasGemmU8X8TryGemvKernel(const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned) +template +MLAS_FORCEINLINE +bool +MlasGemmU8X8TryGemvKernel( + const uint8_t* A, + const uint8_t* B, + size_t ldb, + int32_t* C, + size_t CountK, + size_t CountN, + bool BIsSigned + ) { MLAS_UNREFERENCED_PARAMETER(A); MLAS_UNREFERENCED_PARAMETER(B); @@ -127,25 +149,32 @@ MlasGemmU8X8TryGemvKernel(const uint8_t* A, return false; } -template +template int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { MLAS_UNREFERENCED_PARAMETER(BIsSigned); return ZeroPointB; } -template -MLAS_FORCEINLINE void -MlasGemmU8X8FixupZeroPointB(const uint8_t* PackedZeroPointB, - int32_t* ZeroPointBBuffer, - size_t N, - bool BIsSigned) +template +MLAS_FORCEINLINE +void +MlasGemmU8X8FixupZeroPointB( + const uint8_t* PackedZeroPointB, + int32_t* ZeroPointBBuffer, + size_t N, + bool BIsSigned + ) { int32_t ZeroPointB; for (size_t n = 0; n < N; n++) { + ZeroPointB = typename KernelType::OffsetBType(PackedZeroPointB[n]); ZeroPointB = MlasGemmU8X8FixupZeroPointB(ZeroPointB, BIsSigned); @@ -157,55 +186,62 @@ MlasGemmU8X8FixupZeroPointB(const uint8_t* PackedZeroPointB, // against tools that check for uninitialized data usage. // - size_t AlignedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); + size_t AlignedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); for (size_t n = N; n < AlignedN; n++) { ZeroPointBBuffer[n] = 0; } } -template +template void -MlasGemmU8X8CopyPackA(typename KernelType::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer); +MlasGemmU8X8CopyPackA( + typename KernelType::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ); -template +template void -MlasGemmU8X8CopyPackB(typename KernelType::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned); +MlasGemmU8X8CopyPackB( + typename KernelType::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ); -template +template size_t -MlasGemmU8X8Kernel(const typename KernelType::PackedAType* A, - const typename KernelType::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode); +MlasGemmU8X8Kernel( + const typename KernelType::PackedAType* A, + const typename KernelType::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ); -template +template void -MlasGemmU8X8Operation(const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN) +MlasGemmU8X8Operation( + const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) /*++ Routine Description: @@ -251,8 +287,8 @@ Return Value: const uint8_t* A = Data->A + RangeStartM * lda; const uint8_t* B = (const uint8_t*)Data->B + RangeStartN; int32_t* C = Data->C + RangeStartM * ldc + RangeStartN; - const uint8_t* PackedZeroPointB = - Data->PerColumnZeroPoints ? Data->ZeroPointB + RangeStartN : nullptr; + const uint8_t* PackedZeroPointB = Data->PerColumnZeroPoints ? + Data->ZeroPointB + RangeStartN : nullptr; int32_t ZeroPointA = Data->ZeroPointA; int32_t ZeroPointB = typename KernelType::OffsetBType(*Data->ZeroPointB); @@ -261,8 +297,9 @@ Return Value: // Try to use a GEMV kernel if supported by this kernel type. // - if ((RangeCountM == 1) && (ZeroPointA == 0) && (PackedZeroPointB == nullptr) && - (ZeroPointB == 0) && (Data->OutputProcessor == nullptr)) { + if ((RangeCountM == 1) && + (ZeroPointA == 0) && (PackedZeroPointB == nullptr) && (ZeroPointB == 0) && + (Data->OutputProcessor == nullptr)) { if (MlasGemmU8X8TryGemvKernel(A, B, ldb, C, K, RangeCountN, Shape->BIsSigned)) { return; } @@ -283,6 +320,7 @@ Return Value: size_t CountK; for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, Strides.K); const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; @@ -294,6 +332,7 @@ Return Value: size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, Strides.N); // @@ -302,16 +341,25 @@ Return Value: // if (PackedZeroPointB != nullptr) { - MlasGemmU8X8FixupZeroPointB(PackedZeroPointB + n, ZeroPointBBuffer, - CountN, Shape->BIsSigned); + MlasGemmU8X8FixupZeroPointB( + PackedZeroPointB + n, + ZeroPointBBuffer, + CountN, + Shape->BIsSigned); } // // Copy a panel of matrix B to a local packed buffer. // - MlasGemmU8X8CopyPackB(PanelB, B + n, ldb, CountN, CountK, ColumnSumBuffer, - Shape->BIsSigned); + MlasGemmU8X8CopyPackB( + PanelB, + B + n, + ldb, + CountN, + CountK, + ColumnSumBuffer, + Shape->BIsSigned); MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, CountN, -ZeroPointA); @@ -323,14 +371,20 @@ Return Value: size_t CountM; for (size_t m = 0; m < RangeCountM; m += CountM) { + CountM = std::min(RangeCountM - m, Strides.M); // // Copy a panel of matrix A to a local packed buffer. // - MlasGemmU8X8CopyPackA(PanelA, A + m * lda, lda, CountM, CountK, - RowSumBuffer); + MlasGemmU8X8CopyPackA( + PanelA, + A + m * lda, + lda, + CountM, + CountK, + RowSumBuffer); // // Apply the global depth value constant without the ZeroPointB scaling from: @@ -367,15 +421,28 @@ Return Value: bool PostProcess = (k + CountK == K); while (RowsRemaining > 0) { + size_t RowsHandled = MlasGemmU8X8Kernel( - pa, PanelB, c, PackedCountK, RowsRemaining, CountN, ldc, RowSums, - ColumnSumBuffer, (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, + pa, + PanelB, + c, + PackedCountK, + RowsRemaining, + CountN, + ldc, + RowSums, + ColumnSumBuffer, + (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, ZeroMode); if (PostProcess && Data->OutputProcessor != nullptr) { Data->OutputProcessor->Process( - Data->C, RangeStartM + m + CountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, Data->ldc); + Data->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Data->ldc); } c += ldc * RowsHandled; @@ -391,14 +458,16 @@ Return Value: } } -template +template void -MlasGemmU8X8PackedOperation(const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN) +MlasGemmU8X8PackedOperation( + const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) /*++ Routine Description: @@ -442,8 +511,8 @@ Return Value: const uint8_t* A = Data->A + RangeStartM * lda; const uint8_t* PackedB = (const uint8_t*)Data->B; int32_t* C = Data->C + RangeStartM * ldc + RangeStartN; - const uint8_t* PackedZeroPointB = - Data->PerColumnZeroPoints ? Data->ZeroPointB + RangeStartN : nullptr; + const uint8_t* PackedZeroPointB = Data->PerColumnZeroPoints ? + Data->ZeroPointB + RangeStartN : nullptr; int32_t ZeroPointA = Data->ZeroPointA; int32_t ZeroPointB = typename KernelType::OffsetBType(*Data->ZeroPointB); @@ -473,6 +542,7 @@ Return Value: size_t CountK; for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, Strides.K); const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; @@ -488,11 +558,12 @@ Return Value: size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, Strides.N); if (k == 0) { - MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, PackedColumnSumBuffer + n, CountN, - -ZeroPointA); + MlasGemmU8X8ScaleSumBuffer(ColumnSumBuffer, PackedColumnSumBuffer + n, + CountN, -ZeroPointA); } // @@ -501,27 +572,37 @@ Return Value: // if (PackedZeroPointB != nullptr) { - MlasGemmU8X8FixupZeroPointB(PackedZeroPointB + n, ZeroPointBBuffer, - CountN, Shape->BIsSigned); + MlasGemmU8X8FixupZeroPointB( + PackedZeroPointB + n, + ZeroPointBBuffer, + CountN, + Shape->BIsSigned); } // // Step through each slice of matrix A along the M dimension. // - const uint8_t* b = PackedB + (RangeStartN + n) * KernelType::PackedK * PackedCountK; + const uint8_t* b = PackedB + (RangeStartN + n) * + KernelType::PackedK * PackedCountK; int32_t* c = C + n; size_t CountM; for (size_t m = 0; m < RangeCountM; m += CountM) { + CountM = std::min(RangeCountM - m, Strides.M); // // Copy a panel of matrix A to a local packed buffer. // - MlasGemmU8X8CopyPackA(PanelA, A + m * lda, lda, CountM, CountK, - RowSumBuffer); + MlasGemmU8X8CopyPackA( + PanelA, + A + m * lda, + lda, + CountM, + CountK, + RowSumBuffer); // // Apply the global depth value constant without the ZeroPointB scaling from: @@ -558,15 +639,28 @@ Return Value: bool PostProcess = (k + CountK == K); while (RowsRemaining > 0) { + size_t RowsHandled = MlasGemmU8X8Kernel( - pa, b, c, PackedCountK, RowsRemaining, CountN, ldc, RowSums, - ColumnSumBuffer, (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, + pa, + b, + c, + PackedCountK, + RowsRemaining, + CountN, + ldc, + RowSums, + ColumnSumBuffer, + (PackedZeroPointB != nullptr) ? ZeroPointBBuffer : nullptr, ZeroMode); if (PostProcess && Data->OutputProcessor != nullptr) { Data->OutputProcessor->Process( - Data->C, RangeStartM + m + CountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, Data->ldc); + Data->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Data->ldc); } c += ldc * RowsHandled; @@ -584,7 +678,8 @@ Return Value: #if defined(MLAS_SSE2_INTRINSICS) -struct MLAS_GEMM_U8X8_KERNEL_SSE { +struct MLAS_GEMM_U8X8_KERNEL_SSE +{ typedef int16_t PackedAType; typedef int16_t PackedBType; typedef int8_t OffsetBType; @@ -596,9 +691,13 @@ struct MLAS_GEMM_U8X8_KERNEL_SSE { constexpr size_t MLAS_GEMM_U8X8_KERNEL_SSE::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_SSE::Strides; -template <> -MLAS_FORCEINLINE int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { if (!BIsSigned) { ZeroPointB = MLAS_GEMM_U8X8_KERNEL_SSE::OffsetBType(ZeroPointB ^ 0x80); @@ -607,24 +706,27 @@ MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool return ZeroPointB; } -template <> +template<> void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { const __m128i ZeroVector = _mm_setzero_si128(); const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - uint8_t PaddedMatrixAData[8] = {0}; + uint8_t PaddedMatrixAData[8] = { 0 }; // // Process a single row of matrix A in a loop. // while (CountM > 0) { + const uint8_t* a = A; size_t k = CountK; __m128i ReductionVector = ZeroVector; @@ -643,6 +745,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // while (k >= 8) { + __m128i Bytes = _mm_loadl_epi64((const __m128i*)&a[0]); __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); @@ -656,6 +759,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_SSE::Pack } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -691,10 +795,10 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // ReductionVector = _mm_madd_epi16(ReductionVector, OnesWordBroadcast); - ReductionVector = _mm_add_epi32( - ReductionVector, _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(3, 2, 3, 2))); - ReductionVector = _mm_add_epi32( - ReductionVector, _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(0, 1, 0, 1))); + ReductionVector = _mm_add_epi32(ReductionVector, + _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(3, 2, 3, 2))); + ReductionVector = _mm_add_epi32(ReductionVector, + _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(0, 1, 0, 1))); *RowSumBuffer++ = _mm_cvtsi128_si32(ReductionVector); @@ -705,20 +809,20 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_SSE::Pack MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackBProcessSse(MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, - __m128i BytesRow0, - __m128i BytesRow1, - __m128i BitFlipVector, - __m128i ColumnSums[2]) +MlasGemmU8X8CopyPackBProcessSse( + MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, + __m128i BytesRow0, + __m128i BytesRow1, + __m128i BitFlipVector, + __m128i ColumnSums[2] + ) { __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); BytesInterleaved = _mm_xor_si128(BytesInterleaved, BitFlipVector); - __m128i WordsInterleaved0 = - _mm_srai_epi16(_mm_unpacklo_epi8(BytesInterleaved, BytesInterleaved), 8); - __m128i WordsInterleaved1 = - _mm_srai_epi16(_mm_unpackhi_epi8(BytesInterleaved, BytesInterleaved), 8); + __m128i WordsInterleaved0 = _mm_srai_epi16(_mm_unpacklo_epi8(BytesInterleaved, BytesInterleaved), 8); + __m128i WordsInterleaved1 = _mm_srai_epi16(_mm_unpackhi_epi8(BytesInterleaved, BytesInterleaved), 8); ColumnSums[0] = _mm_add_epi16(ColumnSums[0], WordsInterleaved0); ColumnSums[1] = _mm_add_epi16(ColumnSums[1], WordsInterleaved1); @@ -727,15 +831,17 @@ MlasGemmU8X8CopyPackBProcessSse(MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, _mm_storeu_si128((__m128i*)&D[8], WordsInterleaved1); } -template <> +template<> void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { const __m128i OnesWordBroadcast = _mm_set1_epi16(1); const __m128i BitFlipVector = _mm_set1_epi32(BIsSigned ? 0 : 0x80808080); @@ -745,6 +851,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // while (CountN >= 8) { + const uint8_t* b = B; size_t k = CountK; __m128i ColumnSums[2]; @@ -761,6 +868,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { + __m128i BytesRow0 = _mm_loadl_epi64((const __m128i*)&b[0]); __m128i BytesRow1 = _mm_loadl_epi64((const __m128i*)&b[ldb]); @@ -772,6 +880,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack } if (k > 0) { + __m128i BytesRow0 = _mm_loadl_epi64((const __m128i*)&b[0]); MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); @@ -795,6 +904,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // if (CountN > 0) { + const uint8_t* b = B; size_t k = CountK; __m128i ColumnSums[2]; @@ -811,6 +921,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack // while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { + const uint8_t* bcopy = b; uint8_t* padded = PaddedMatrixBData; uint8_t* padded_end = padded + CountN; @@ -833,6 +944,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack } if (k > 0) { + const uint8_t* bcopy = b; uint8_t* padded = PaddedMatrixBData; uint8_t* padded_end = padded + CountN; @@ -858,7 +970,11 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_SSE::Pack MLAS_FORCEINLINE void -MlasGemmU8X8MultiplyAccumulateRowSse(__m128i ABroadcast, const int16_t* B, __m128i Accumulators[2]) +MlasGemmU8X8MultiplyAccumulateRowSse( + __m128i ABroadcast, + const int16_t* B, + __m128i Accumulators[2] + ) { __m128i BElements0 = _mm_load_si128((__m128i*)&B[0]); __m128i BElements1 = _mm_load_si128((__m128i*)&B[8]); @@ -867,24 +983,27 @@ MlasGemmU8X8MultiplyAccumulateRowSse(__m128i ABroadcast, const int16_t* B, __m12 Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_madd_epi16(BElements1, ABroadcast)); } -template <> +template<> size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { MLAS_UNREFERENCED_PARAMETER(CountM); MLAS_UNREFERENCED_PARAMETER(ldc); while (CountN > 0) { + __m128i Accumulators[2]; // @@ -894,6 +1013,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P int32_t RowSumValue = RowSumBuffer[0]; if (ZeroPointB != nullptr) { + int32_t ScaledRowSumBuffer[8]; for (size_t i = 0; i < 8; i++) { @@ -906,14 +1026,13 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P Accumulators[1] = _mm_loadu_si128((__m128i*)&ScaledRowSumBuffer[4]); } else { + Accumulators[0] = _mm_set1_epi32(RowSumValue); Accumulators[1] = Accumulators[0]; } - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[0])); - Accumulators[1] = - _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[4])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[0])); + Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[4])); ColumnSumBuffer += 8; // @@ -926,6 +1045,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P size_t k = PackedCountK; while (k >= 4) { + __m128i AElements = _mm_loadu_si128((__m128i*)a); __m128i ABroadcast; @@ -947,6 +1067,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P } while (k > 0) { + __m128i ABroadcast = _mm_set1_epi32(*((int32_t*)a)); MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[0], Accumulators); @@ -961,6 +1082,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P // if (CountN >= 8) { + if (!ZeroMode) { Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*)&C[4])); @@ -973,14 +1095,15 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P CountN -= 8; } else { + // // Output the remaining partial output block. // if ((CountN & 4) != 0) { + if (!ZeroMode) { - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); } _mm_storeu_si128((__m128i*)&C[0], Accumulators[0]); @@ -990,9 +1113,9 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P } if ((CountN & 2) != 0) { + if (!ZeroMode) { - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*)&C[0])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*)&C[0])); } _mm_storel_epi64((__m128i*)&C[0], Accumulators[0]); @@ -1002,6 +1125,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_SSE::P } if ((CountN & 1) != 0) { + int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulators[0]); if (!ZeroMode) { @@ -1032,7 +1156,8 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchSse = { // for this code is Windows only, so restrict this kernel to that environment. #if defined(MLAS_SSE2_INTRINSICS) && defined(_MSC_VER) -struct MLAS_GEMM_U8S8_KERNEL_SSE41 { +struct MLAS_GEMM_U8S8_KERNEL_SSE41 +{ typedef uint8_t PackedAType; typedef uint8_t PackedBType; typedef int8_t OffsetBType; @@ -1046,14 +1171,16 @@ constexpr size_t MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_SSE41::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides; -template <> +template<> void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { const __m128i ZeroVector = _mm_setzero_si128(); const __m128i OnesWordBroadcast = _mm_set1_epi16(1); @@ -1063,6 +1190,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41:: // while (CountM > 0) { + const uint8_t* a = A; size_t k = CountK; __m128i ReductionVector = ZeroVector; @@ -1076,11 +1204,11 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41:: // while (k >= 8) { + __m128i Bytes = _mm_loadl_epi64((const __m128i*)&a[0]); __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - ReductionVector = - _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); + ReductionVector = _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); _mm_storel_epi64((__m128i*)&D[0], Bytes); @@ -1090,6 +1218,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41:: } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -1102,8 +1231,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41:: D += (k + 3) & ~3; __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - ReductionVector = - _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); + ReductionVector = _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); } // @@ -1122,11 +1250,13 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_SSE41:: MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackBProcessSse41(MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, - __m128i BytesRows[4], - __m128i OnesByteBroadcast, - __m128i OnesWordBroadcast, - __m128i ColumnSums[2]) +MlasGemmU8X8CopyPackBProcessSse41( + MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, + __m128i BytesRows[4], + __m128i OnesByteBroadcast, + __m128i OnesWordBroadcast, + __m128i ColumnSums[2] + ) { __m128i PairsInterleaved0 = _mm_unpacklo_epi8(BytesRows[0], BytesRows[1]); __m128i PairsInterleaved1 = _mm_unpacklo_epi8(BytesRows[2], BytesRows[3]); @@ -1147,15 +1277,17 @@ MlasGemmU8X8CopyPackBProcessSse41(MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, _mm_storeu_si128((__m128i*)&D[16], QuadsInterleaved1); } -template <> +template<> void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { const __m128i OnesByteBroadcast = _mm_set1_epi8(1); const __m128i OnesWordBroadcast = _mm_set1_epi16(1); @@ -1168,6 +1300,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: // while (CountN >= 8) { + const uint8_t* b = B; size_t k = CountK; __m128i ColumnSums[2]; @@ -1180,13 +1313,13 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: // while (k >= MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK) { + BytesRows[0] = _mm_loadl_epi64((const __m128i*)&b[ldb * 0]); BytesRows[1] = _mm_loadl_epi64((const __m128i*)&b[ldb * 1]); BytesRows[2] = _mm_loadl_epi64((const __m128i*)&b[ldb * 2]); BytesRows[3] = _mm_loadl_epi64((const __m128i*)&b[ldb * 3]); - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, - ColumnSums); + MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); b += ldb * 4; D += 32; @@ -1194,6 +1327,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: } if (k > 0) { + BytesRows[0] = _mm_loadl_epi64((const __m128i*)&b[ldb * 0]); BytesRows[1] = _mm_setzero_si128(); BytesRows[2] = _mm_setzero_si128(); @@ -1207,8 +1341,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: BytesRows[2] = _mm_loadl_epi64((const __m128i*)&b[ldb * 2]); } - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, - ColumnSums); + MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); D += 32; } @@ -1226,6 +1359,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: // if (CountN > 0) { + const __m128i ZeroVector = _mm_setzero_si128(); __m128i ColumnSums[2]; @@ -1235,6 +1369,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: ColumnSums[1] = _mm_setzero_si128(); while (CountK > 0) { + size_t k = std::min(CountK, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK); CountK -= k; @@ -1244,6 +1379,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: uint8_t* padded = PaddedMatrixBData; do { + std::copy_n(B, CountN, padded); padded += 8; @@ -1257,8 +1393,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: BytesRows[2] = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[16]); BytesRows[3] = _mm_loadl_epi64((__m128i*)&PaddedMatrixBData[24]); - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, - ColumnSums); + MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); D += 32; } @@ -1270,10 +1405,12 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_SSE41:: MLAS_FORCEINLINE void -MlasGemmU8X8MultiplyAccumulateRowSse41(__m128i ABroadcast, - const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, - __m128i OnesWordBroadcast, - __m128i Accumulators[2]) +MlasGemmU8X8MultiplyAccumulateRowSse41( + __m128i ABroadcast, + const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, + __m128i OnesWordBroadcast, + __m128i Accumulators[2] + ) { __m128i BElements0 = _mm_load_si128((__m128i*)&B[0]); __m128i BElements1 = _mm_load_si128((__m128i*)&B[16]); @@ -1281,25 +1418,25 @@ MlasGemmU8X8MultiplyAccumulateRowSse41(__m128i ABroadcast, __m128i Intermediate0 = _mm_maddubs_epi16(ABroadcast, BElements0); __m128i Intermediate1 = _mm_maddubs_epi16(ABroadcast, BElements1); - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_madd_epi16(Intermediate0, OnesWordBroadcast)); - Accumulators[1] = - _mm_add_epi32(Accumulators[1], _mm_madd_epi16(Intermediate1, OnesWordBroadcast)); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_madd_epi16(Intermediate0, OnesWordBroadcast)); + Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_madd_epi16(Intermediate1, OnesWordBroadcast)); } -template <> +template<> size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* A, - const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* A, + const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { const __m128i OnesWordBroadcast = _mm_set1_epi16(1); @@ -1307,6 +1444,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 MLAS_UNREFERENCED_PARAMETER(ldc); while (CountN > 0) { + __m128i Accumulators[2]; // @@ -1317,17 +1455,13 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 Accumulators[1] = Accumulators[0]; if (ZeroPointB != nullptr) { - Accumulators[0] = - _mm_mullo_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ZeroPointB[0])); - Accumulators[1] = - _mm_mullo_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ZeroPointB[4])); + Accumulators[0] = _mm_mullo_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ZeroPointB[0])); + Accumulators[1] = _mm_mullo_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ZeroPointB[4])); ZeroPointB += 8; } - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[0])); - Accumulators[1] = - _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[4])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[0])); + Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*)&ColumnSumBuffer[4])); ColumnSumBuffer += 8; // @@ -1340,24 +1474,21 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 size_t k = PackedCountK; while (k >= 4) { + __m128i AElements = _mm_loadu_si128((__m128i*)a); __m128i ABroadcast; ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, - Accumulators); + MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, Accumulators); ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[32], OnesWordBroadcast, - Accumulators); + MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[32], OnesWordBroadcast, Accumulators); ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(2, 2, 2, 2)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[64], OnesWordBroadcast, - Accumulators); + MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[64], OnesWordBroadcast, Accumulators); ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(3, 3, 3, 3)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[96], OnesWordBroadcast, - Accumulators); + MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[96], OnesWordBroadcast, Accumulators); a += 4 * 4; B += 4 * 32; @@ -1365,9 +1496,9 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 } while (k > 0) { + __m128i ABroadcast = _mm_set1_epi32(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, - Accumulators); + MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, Accumulators); a += 4; B += 32; @@ -1380,6 +1511,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 // if (CountN >= 8) { + if (!ZeroMode) { Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*)&C[4])); @@ -1392,14 +1524,15 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 CountN -= 8; } else { + // // Output the remaining partial output block. // if ((CountN & 4) != 0) { + if (!ZeroMode) { - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*)&C[0])); } _mm_storeu_si128((__m128i*)&C[0], Accumulators[0]); @@ -1409,9 +1542,9 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 } if ((CountN & 2) != 0) { + if (!ZeroMode) { - Accumulators[0] = - _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*)&C[0])); + Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*)&C[0])); } _mm_storel_epi64((__m128i*)&C[0], Accumulators[0]); @@ -1421,6 +1554,7 @@ MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_SSE4 } if ((CountN & 1) != 0) { + int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulators[0]); if (!ZeroMode) { @@ -1453,9 +1587,8 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchSse41 = { // Stores a vector to transpose a 4x4 byte vector using vpshufb. // -MLAS_INTERNAL_DATA -MLAS_DECLSPEC_ALIGN(const uint8_t MlasTranspose4x4BytesAvx[16], 16) = {0, 4, 8, 12, 1, 5, 9, 13, - 2, 6, 10, 14, 3, 7, 11, 15}; +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint8_t MlasTranspose4x4BytesAvx[16], 16) = + { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; // // Define the prototypes of the AVX2/AVX512 routines written in assembly. @@ -1463,33 +1596,54 @@ MLAS_DECLSPEC_ALIGN(const uint8_t MlasTranspose4x4BytesAvx[16], 16) = {0, 4, 8, extern "C" { -void MLASCALL -MlasGemmU8S8CopyPackAAvx2( - uint8_t* D, const uint8_t* A, size_t lda, size_t CountM, size_t CountK, int32_t* RowSumBuffer); + void + MLASCALL + MlasGemmU8S8CopyPackAAvx2( + uint8_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ); -void MLASCALL -MlasGemmU8S8CopyPackBAvx2(uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned); + void + MLASCALL + MlasGemmU8S8CopyPackBAvx2( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ); -void MLASCALL -MlasGemmU8U8CopyPackAAvx2( - int16_t* D, const uint8_t* A, size_t lda, size_t CountM, size_t CountK, int32_t* RowSumBuffer); + void + MLASCALL + MlasGemmU8U8CopyPackAAvx2( + int16_t* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ); -void MLASCALL -MlasGemmU8U8CopyPackBAvx2(uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer); + void + MLASCALL + MlasGemmU8U8CopyPackBAvx2( + uint8_t* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer + ); } -struct MLAS_GEMM_U8S8_KERNEL_AVX2 { +struct MLAS_GEMM_U8S8_KERNEL_AVX2 +{ typedef uint8_t PackedAType; typedef uint8_t PackedBType; typedef int8_t OffsetBType; @@ -1503,15 +1657,18 @@ constexpr size_t MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides; -template <> -MLAS_FORCEINLINE bool -MlasGemmU8X8TryGemvKernel(const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool BIsSigned) +template<> +MLAS_FORCEINLINE +bool +MlasGemmU8X8TryGemvKernel( + const uint8_t* A, + const uint8_t* B, + size_t ldb, + int32_t* C, + size_t CountK, + size_t CountN, + bool BIsSigned + ) { if (BIsSigned) { MlasPlatform.GemvU8S8Kernel(A, B, C, CountK, CountN, ldb); @@ -1521,9 +1678,13 @@ MlasGemmU8X8TryGemvKernel(const uint8_t* A, return false; } -template <> -MLAS_FORCEINLINE int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { if (!BIsSigned) { ZeroPointB = MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetBType(ZeroPointB ^ 0x80); @@ -1532,47 +1693,56 @@ MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool return ZeroPointB; } -template <> -MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); } -template <> -MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { MlasGemmU8S8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); } -template <> -MLAS_FORCEINLINE size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* A, + const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { - return MlasPlatform.GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, - ColumnSumBuffer, ZeroPointB, ZeroMode); + return MlasPlatform.GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); } const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2 = { @@ -1583,7 +1753,8 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8S8DispatchAvx2 = { MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K, }; -struct MLAS_GEMM_U8U8_KERNEL_AVX2 { +struct MLAS_GEMM_U8U8_KERNEL_AVX2 +{ typedef int16_t PackedAType; typedef uint8_t PackedBType; typedef uint8_t OffsetBType; @@ -1597,49 +1768,58 @@ constexpr size_t MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides; -template <> -MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { MlasGemmU8U8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); } -template <> -MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +template<> +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { MLAS_UNREFERENCED_PARAMETER(BIsSigned); MlasGemmU8U8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer); } -template <> -MLAS_FORCEINLINE size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* A, + const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { - return MlasPlatform.GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, - ColumnSumBuffer, ZeroPointB, ZeroMode); + return MlasPlatform.GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); } const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2 = { @@ -1662,21 +1842,25 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8U8DispatchAvx2 = { extern "C" { -size_t MLASCALL -MlasGemmU8X8KernelNeon(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode); + size_t + MLASCALL + MlasGemmU8X8KernelNeon( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB, + bool ZeroMode + ); } -struct MLAS_GEMM_U8X8_KERNEL_NEON { +struct MLAS_GEMM_U8X8_KERNEL_NEON +{ typedef uint8_t PackedAType; typedef uint8_t PackedBType; typedef uint8_t OffsetBType; @@ -1690,9 +1874,13 @@ constexpr size_t MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides; -template <> -MLAS_FORCEINLINE int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { if (BIsSigned) { ZeroPointB = MLAS_GEMM_U8X8_KERNEL_NEON::OffsetBType(ZeroPointB ^ 0x80); @@ -1701,14 +1889,16 @@ MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool return ZeroPointB; } -template <> +template<> void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { uint8_t PaddedMatrixAData[16]; @@ -1728,6 +1918,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // while (CountM >= 4) { + const uint8_t* a0 = A; const uint8_t* a1 = a0 + lda; const uint8_t* a2 = a1 + lda; @@ -1737,6 +1928,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa uint32x4_t RowSums = vmovq_n_u32(0); while (k >= 16) { + uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); a0 += 16; uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); @@ -1789,6 +1981,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } while (k >= 4) { + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; uint32_t v1 = *reinterpret_cast(a1); @@ -1810,6 +2003,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -1819,6 +2013,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); while (k > 0) { + d[0] = *a0++; d[4] = *a1++; d[8] = *a2++; @@ -1859,6 +2054,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // if ((CountM & 2) != 0) { + const uint8_t* a0 = A; const uint8_t* a1 = a0 + lda; @@ -1866,6 +2062,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa uint32x2_t RowSums = vmov_n_u32(0); while (k >= 4) { + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; uint32_t v1 = *reinterpret_cast(a1); @@ -1881,6 +2078,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -1890,6 +2088,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); while (k > 0) { + d[0] = *a0++; d[4] = *a1++; @@ -1927,11 +2126,13 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // if ((CountM & 1) != 0) { + const uint8_t* a = A; size_t k = CountK; uint32x4_t RowSums = vmovq_n_u32(0); while (k >= 16) { + uint8x16_t v = vld1q_u8(a); a += 16; @@ -1944,6 +2145,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -1980,10 +2182,12 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackBProcessNeon(uint8_t* D, - const uint8_t* B, - uint8x8_t BitFlipVector, - uint32x4_t ColumnSums[2]) +MlasGemmU8X8CopyPackBProcessNeon( + uint8_t* D, + const uint8_t* B, + uint8x8_t BitFlipVector, + uint32x4_t ColumnSums[2] + ) { uint8x8_t BytesRow = veor_u8(vld1_u8(B), BitFlipVector); vst1_u8(D, BytesRow); @@ -1997,20 +2201,22 @@ MlasGemmU8X8CopyPackBProcessNeon(uint8_t* D, #endif } -template <> +template<> void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { const uint8x8_t BitFlipVector = vdup_n_u8(BIsSigned ? 0x80 : 0); const uint8x8_t ZeroVector = vmov_n_u8(0); - const size_t AlignedCountK = (CountK + MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1) & - ~(MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1); + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1); // // Process 8 columns of matrix B in a loop. @@ -2026,6 +2232,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // while (CountN >= 8) { + const uint8_t* b = B; uint32x4_t ColumnSums[2]; @@ -2033,6 +2240,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::Pa ColumnSums[1] = vmovq_n_u32(0); for (size_t k = CountK; k > 0; k--) { + MlasGemmU8X8CopyPackBProcessNeon(D, b, BitFlipVector, ColumnSums); b += ldb; @@ -2057,6 +2265,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // if (CountN > 0) { + const uint8_t* b = B; uint8_t PaddedMatrixBData[8]; uint32x4_t ColumnSums[2]; @@ -2067,6 +2276,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::Pa ColumnSums[1] = vmovq_n_u32(0); for (size_t k = CountK; k > 0; k--) { + for (size_t n = 0; n < CountN; n++) { PaddedMatrixBData[n] = b[n]; } @@ -2087,22 +2297,25 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } } -template <> -MLAS_FORCEINLINE size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { - return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, - ColumnSumBuffer, ZeroPointB, ZeroMode); + return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); } const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon = { @@ -2123,21 +2336,25 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchNeon = { extern "C" { -size_t MLASCALL -MlasGemmU8X8KernelUdot(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode); + size_t + MLASCALL + MlasGemmU8X8KernelUdot( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB, + bool ZeroMode + ); } -struct MLAS_GEMM_U8X8_KERNEL_UDOT { +struct MLAS_GEMM_U8X8_KERNEL_UDOT +{ typedef uint8_t PackedAType; typedef uint8_t PackedBType; typedef uint8_t OffsetBType; @@ -2151,9 +2368,13 @@ constexpr size_t MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::Strides; constexpr MLAS_GEMM_U8X8_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides; -template <> -MLAS_FORCEINLINE int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { if (BIsSigned) { ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UDOT::OffsetBType(ZeroPointB ^ 0x80); @@ -2162,14 +2383,16 @@ MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool return ZeroPointB; } -template <> +template<> void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { uint8_t PaddedMatrixAData[16]; @@ -2189,6 +2412,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // while (CountM >= 4) { + const uint8_t* a0 = A; const uint8_t* a1 = a0 + lda; const uint8_t* a2 = a1 + lda; @@ -2198,6 +2422,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa uint32x4_t RowSums = vmovq_n_u32(0); while (k >= 16) { + uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); a0 += 16; uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); @@ -2232,6 +2457,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } while (k >= 4) { + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; uint32_t v1 = *reinterpret_cast(a1); @@ -2253,6 +2479,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -2262,6 +2489,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); while (k > 0) { + d[0] = *a0++; d[4] = *a1++; d[8] = *a2++; @@ -2280,6 +2508,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (((CountK - 1) & 7) < 4) { + vst1q_u8(D, vmovq_n_u8(0)); D += 16; @@ -2308,6 +2537,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // if (CountM >= 2) { + const uint8_t* a0 = A; const uint8_t* a1 = a0 + lda; @@ -2315,6 +2545,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa uint32x2_t RowSums = vmov_n_u32(0); while (k >= 4) { + uint32_t v0 = *reinterpret_cast(a0); a0 += 4; uint32_t v1 = *reinterpret_cast(a1); @@ -2330,6 +2561,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -2339,6 +2571,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa vst1_u8(PaddedMatrixAData, vmov_n_u8(0)); while (k > 0) { + d[0] = *a0++; d[4] = *a1++; @@ -2355,6 +2588,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (((CountK - 1) & 7) < 4) { + vst1_u8(D, vmov_n_u8(0)); D += 8; @@ -2382,11 +2616,13 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa // if (CountM > 0) { + const uint8_t* a = A; size_t k = CountK; uint32x4_t RowSums = vmovq_n_u32(0); while (k >= 16) { + uint8x16_t v = vld1q_u8(a); a += 16; @@ -2399,6 +2635,7 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa } if (k > 0) { + // // Copy the remaining bytes to the zero padded stack buffer. // @@ -2431,10 +2668,12 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_NEON::Pa MLAS_FORCEINLINE void -MlasGemmU8X8CopyPackBProcessUdot(MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, - uint8x8_t BytesRow[4], - uint8x16_t BitFlipVector, - uint32x4_t ColumnSums[2]) +MlasGemmU8X8CopyPackBProcessUdot( + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, + uint8x8_t BytesRow[4], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2] + ) { uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); @@ -2449,15 +2688,17 @@ MlasGemmU8X8CopyPackBProcessUdot(MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, ColumnSums[1] = vpadalq_u16(ColumnSums[1], vpaddlq_u8(vreinterpretq_u8_u16(zd.val[1]))); } -template <> +template<> void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { const uint8x16_t ZeroVector = vmovq_n_u8(0); const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); @@ -2485,6 +2726,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // while (CountN >= 8) { + const uint8_t* b = B; size_t k = CountK; uint32x4_t ColumnSums[2]; @@ -2497,6 +2739,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // while (k >= 4) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); BytesRow[1] = vld1_u8(&b[ldb * 1]); BytesRow[2] = vld1_u8(&b[ldb * 2]); @@ -2510,6 +2753,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa } if (k > 0) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); BytesRow[2] = (k > 2) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); @@ -2526,6 +2770,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // if (((CountK - 1) & 7) < 4) { + vst1q_u8(&D[0], ZeroVector); vst1q_u8(&D[16], ZeroVector); @@ -2545,6 +2790,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // if (CountN > 0) { + const uint8_t* b = B; size_t k = CountK; uint8_t PaddedMatrixBData[32]; @@ -2562,16 +2808,19 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; const uint8_t* bcopy1 = &b[ldb * 1]; const uint8_t* bcopy2 = &b[ldb * 2]; const uint8_t* bcopy3 = &b[ldb * 3]; if (k >= 4) { + b += ldb * 4; k -= 4; } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); @@ -2608,6 +2857,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa // if (((CountK - 1) & 7) < 4) { + vst1q_u8(&D[0], ZeroVector); vst1q_u8(&D[16], ZeroVector); @@ -2619,22 +2869,25 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_UDOT::Pa } } -template <> -MLAS_FORCEINLINE size_t -MlasGemmU8X8Kernel(const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) +template<> +MLAS_FORCEINLINE +size_t +MlasGemmU8X8Kernel( + const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) { - return MlasGemmU8X8KernelUdot(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, - ColumnSumBuffer, ZeroPointB, ZeroMode); + return MlasGemmU8X8KernelUdot(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); } const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot = { @@ -2647,7 +2900,8 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchUdot = { #endif -struct MLAS_GEMM_U8X8_KERNEL_DEFAULT { +struct MLAS_GEMM_U8X8_KERNEL_DEFAULT +{ typedef uint8_t PackedAType; typedef uint8_t PackedBType; typedef uint8_t OffsetBType; @@ -2657,9 +2911,13 @@ struct MLAS_GEMM_U8X8_KERNEL_DEFAULT { static constexpr MLAS_GEMM_U8X8_STRIDES PackedStrides{16, 128, 128}; }; -template <> -MLAS_FORCEINLINE int32_t -MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmU8X8FixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) { if (BIsSigned) { ZeroPointB = MLAS_GEMM_U8X8_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80); @@ -2668,26 +2926,30 @@ MlasGemmU8X8FixupZeroPointB(int32_t ZeroPointB, b return ZeroPointB; } -template <> +template<> void -MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer) +MlasGemmU8X8CopyPackA( + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) { - const size_t AlignedCountK = (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & - ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); // // Process a single row of matrix A in a loop. // while (CountM-- > 0) { + int32_t RowSum = 0; for (size_t k = 0; k < CountK; k++) { + uint8_t a0 = A[k]; D[k] = a0; @@ -2705,18 +2967,20 @@ MlasGemmU8X8CopyPackA(MLAS_GEMM_U8X8_KERNEL_DEFAU } } -template <> +template<> void -MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) +MlasGemmU8X8CopyPackB( + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) { - const size_t AlignedCountK = (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & - ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); + const size_t AlignedCountK = + (CountK + MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedK - 1); const uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); // @@ -2724,6 +2988,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_DEFAU // while (CountN-- > 0) { + const uint8_t* b = B; int32_t ColumnSum = 0; @@ -2732,6 +2997,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_DEFAU // for (size_t k = 0; k < CountK; k++) { + uint8_t b0 = b[0] ^ BitFlipValue; D[k] = b0; @@ -2751,7 +3017,7 @@ MlasGemmU8X8CopyPackB(MLAS_GEMM_U8X8_KERNEL_DEFAU } } -template <> +template<> size_t MlasGemmU8X8Kernel( const MLAS_GEMM_U8X8_KERNEL_DEFAULT::PackedAType* A, @@ -2764,7 +3030,8 @@ MlasGemmU8X8Kernel( const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode) + bool ZeroMode + ) { MLAS_UNREFERENCED_PARAMETER(CountM); MLAS_UNREFERENCED_PARAMETER(ldc); @@ -2774,6 +3041,7 @@ MlasGemmU8X8Kernel( // while (CountN-- > 0) { + int32_t Accumulator = *RowSumBuffer; if (ZeroPointB != nullptr) { @@ -2785,6 +3053,7 @@ MlasGemmU8X8Kernel( const auto* a = A; for (size_t k = 0; k < PackedCountK; k++) { + Accumulator += a[0] * B[0]; Accumulator += a[1] * B[1]; Accumulator += a[2] * B[2]; @@ -2813,11 +3082,14 @@ const MLAS_GEMM_U8X8_DISPATCH MlasGemmU8X8DispatchDefault = { 0, }; + void -MlasGemmU8X8Threaded(const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock, - const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS* Data, - ptrdiff_t ThreadId) +MlasGemmU8X8Threaded( + const MLAS_GEMM_U8X8_WORK_BLOCK* WorkBlock, + const MLAS_GEMM_U8X8_SHAPE_PARAMS* Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS* Data, + ptrdiff_t ThreadId + ) /*++ Routine Description: @@ -2864,10 +3136,11 @@ Return Value: const size_t N = Shape->N; - const size_t BlockedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_QGEMM_STRIDEN_THREAD_ALIGN; + const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - MlasPartitionWork(ThreadIdN, WorkBlock->ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + MlasPartitionWork(ThreadIdN, WorkBlock->ThreadCountN, BlockedN, + &RangeStartN, &RangeCountN); RangeStartN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; RangeCountN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; @@ -2890,10 +3163,13 @@ Return Value: GemmU8X8Operation(Shape, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } -void MLASCALL -MlasGemm(const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS& DataParams, - MLAS_THREADPOOL* ThreadPool) + +void +MLASCALL +MlasGemm( + const MLAS_GEMM_U8X8_SHAPE_PARAMS &Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS &DataParams, + MLAS_THREADPOOL *ThreadPool) /*++ Routine Description: @@ -2919,11 +3195,13 @@ Return Value: MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } -void MLASCALL -MlasGemmBatch(const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, - const MLAS_GEMM_U8X8_DATA_PARAMS* DataParams, - const size_t BatchN, - MLAS_THREADPOOL* ThreadPool) +void +MLASCALL +MlasGemmBatch( + const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, + const MLAS_GEMM_U8X8_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool) { const size_t M = Shape.M; const size_t N = Shape.N; @@ -2965,8 +3243,9 @@ MlasGemmBatch(const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, MLAS_GEMM_U8X8_WORK_BLOCK WorkBlock; if (N > M) { - const size_t BlockedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_QGEMM_STRIDEN_THREAD_ALIGN; + + const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_QGEMM_STRIDEN_THREAD_ALIGN; if (size_t(ThreadsPerGemm) > BlockedN) { ThreadsPerGemm = ptrdiff_t(BlockedN); @@ -2976,6 +3255,7 @@ MlasGemmBatch(const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, WorkBlock.ThreadCountN = ThreadsPerGemm; } else { + if (size_t(ThreadsPerGemm) > M) { ThreadsPerGemm = ptrdiff_t(M); } @@ -2992,8 +3272,14 @@ MlasGemmBatch(const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape, }); } -size_t MLASCALL -MlasGemmPackBSize(size_t N, size_t K, bool BIsSigned) + +size_t +MLASCALL +MlasGemmPackBSize( + size_t N, + size_t K, + bool BIsSigned + ) /*++ Routine Description: @@ -3041,14 +3327,22 @@ Return Value: const size_t BytesRequired = (AlignedN * sizeof(int32_t)) + (AlignedN * AlignedK * sizeof(uint8_t)); const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = - (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & + ~(BufferAlignment - 1); return AlignedBytesRequired; } -void MLASCALL -MlasGemmPackB(size_t N, size_t K, const uint8_t* B, size_t ldb, bool BIsSigned, void* PackedB) +void +MLASCALL +MlasGemmPackB( + size_t N, + size_t K, + const uint8_t* B, + size_t ldb, + bool BIsSigned, + void* PackedB + ) /*++ Routine Description: @@ -3105,6 +3399,7 @@ Return Value: size_t CountK; for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, PackedStrideK); // @@ -3116,13 +3411,13 @@ Return Value: size_t CountN; for (size_t n = 0; n < N; n += CountN) { + constexpr size_t BatchedN = 128; MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[BatchedN], 64); CountN = std::min(N - n, BatchedN); - GemmU8X8Dispatch->CopyPackBRoutine(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, - BIsSigned); + GemmU8X8Dispatch->CopyPackBRoutine(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); // // Accumulate this batch of the column sum buffer into the packed