From 20cd3394fc74c35fdbc532c5c31026bc7179aed8 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:39:26 -0700 Subject: [PATCH] [MLAS] AArch64 SQNBitGemm CompInt8 initial multi-row implementation (#21193) Update AArch64 SQNBitGemm CompInt8 kernels to process matrix in tiles. E.g., computing the output in 2x2 tiles allows us to compute four elements of the output with one read of two rows of A and two columns of B. Also moved some code around as it was getting big for a single file. --- cmake/onnxruntime_mlas.cmake | 8 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 56 +- onnxruntime/core/mlas/lib/sqnbitgemm.h | 17 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 40 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 40 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 8 +- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 1343 +---------------- .../core/mlas/lib/sqnbitgemm_kernel_neon.h | 144 ++ .../mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp | 646 ++++++++ .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 1315 ++++++++++++++++ .../test/mlas/unittest/test_sqnbitgemm.cpp | 13 +- 12 files changed, 2248 insertions(+), 1384 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 304aa77f54..54bddcbdcf 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) set(mlas_platform_preprocess_srcs @@ -350,9 +353,12 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 4b852be951..81789386a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,10 +16,11 @@ Abstract: --*/ #include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" #include +#include "sqnbitgemm_q8_block.h" + namespace { @@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable( Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr; } default: { @@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } + if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc ); } c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; } } @@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8( const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. - // TODO Replace it with an optimized implementation. size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); @@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc ); } - c_blk += ldc; - a_row += lda; + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index effb59b250..8321dcc217 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. * * @param BlkLen Number of values in a block. * @param QuantA Supplies the quantized A matrix. @@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param QuantBScale Supplies the quantized B matrix block scale values. * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. + * @param CountM Number of rows of A and C to process, an upper bound. + * @param CountN Number of columns of B and C to process. * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param BlockCountK Number of blocks in one row of A and one column of B. + * @param ldc Number of elements between adjacent rows of C. * @param Bias Bias vector of length N. + * + * @return The number of rows of A and C that were processed, at most CountM. */ - typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); - SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index be573381c3..0922f5ef64 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +size_t +SQ4BitGemmKernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + template MLAS_FORCEINLINE void ComputeDotProducts_BlkLen16_CompFp32_avx2( @@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 0099b61d81..b868906760 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 27310d8253..6477a2019b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +size_t +SQ4BitGemmKernel_CompInt8_avx512vnni( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx512vnni( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + void MLASCALL MlasQ80BlkQuantRow_avx512( size_t BlkLen, @@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index cfc0564cd0..706e08fc46 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( const size_t BlockStrideQuantB ); -void -SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 6d1864794f..3f32cc6c53 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + sqnbitgemm_kernel_neon.cpp Abstract: @@ -17,20 +17,22 @@ Abstract: #include -#include #include -#include #include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" +namespace sqnbitgemm_neon +{ + +namespace +{ + // // Quantized B data packing function implementation. // -namespace -{ - size_t SQ4BitGemmPackQuantBDataSize( size_t N, @@ -134,7 +136,7 @@ SQ4BitGemmPerGemmWorkspaceSize( { MLAS_UNREFERENCED_PARAMETER(N); - switch(ComputeType) { + switch (ComputeType) { case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -167,1316 +169,7 @@ SQ4BitGemmPerGemmWorkspaceAlignment( } // namespace -// -// General helpers. -// - -namespace -{ - -template -MLAS_FORCEINLINE void -UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) -{ - (f(Indices), ...); -} - -template -MLAS_FORCEINLINE void -UnrolledLoop(IterationFn&& f) -{ - UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); -} - -MLAS_FORCEINLINE void -Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) -{ - // aN: aN_0 aN_1 aN_2 aN_3 - - float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 - float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 - float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 - float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 - - // a0_0 a1_0 a2_0 a3_0 - a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_1 a1_1 a2_1 a3_1 - a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_2 a1_2 a3_2 a3_2 - a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); - // a0_3 a1_3 a2_3 a3_3 - a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); -} - -MLAS_FORCEINLINE float32x4_t -FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) -{ - Transpose4x4(a0, a1, a2, a3); - return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); -} - -template -MLAS_FORCEINLINE void -LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) -{ - static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); - - assert(count <= Capacity); - - size_t vi = 0; // vector index - - // handle 4 values at a time - while (count > 3) { - dst[vi] = vld1q_f32(src); - - vi += 1; - src += 4; - count -= 4; - } - - // handle remaining values - if (count > 0) { - dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); - - if (count > 1) { - dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); - - if (count > 2) { - dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); - } - } - } -} - -} // namespace - -// -// CompFp32 kernel implementation. -// - -namespace -{ - -namespace fp32_conversion -{ - -// Manual conversion to float takes place in two steps: -// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. -// This target float range is convenient because the 4-bit source values can be placed directly into the -// target float bits. -// 2. Subtract the conversion offset of 16 from the float result. - -// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. -constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; -// sign|exponent|partial mantissa -// +|131: 2^4|~~~~ <- 4 bits go here - -const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); - -constexpr float offset = 16.0f; - -} // namespace fp32_conversion - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompFp32( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; - - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - - assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); - - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - float32x4_t acc[NCols]{}; - - const std::byte* QuantBData = QuantBDataColPtr; - const float* QuantBScale = QuantBScaleColPtr; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint is true - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - float scale[NCols]; - UnrolledLoop( - [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } - ); - - [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. - // only used if HasZeroPoint is true - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = fp32_conversion::offset + std::to_integer(zp); - }); - } - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { - // load A row vector elements - - // load `SubBlkLen` elements from A, padded with 0's if there aren't enough - const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); - float32x4_t av[4]{}; - LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(scale[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // c[m,n] += a[m,k] * b[k,n] - UnrolledLoop<4>([&](size_t j) { - UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); - }); - } - - // increment pointers to next block - QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScale += 1; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } - - if constexpr (NCols == 4) { - float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - - if (BiasPtr != nullptr) { - sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); - } - - vst1q_f32(SumPtr, sum); - } else { - for (size_t i = 0; i < NCols; ++i) { - SumPtr[i] = vaddvq_f32(acc[i]); - if (BiasPtr != nullptr) { - SumPtr[i] += BiasPtr[i]; - } - } - } -} - -template -void -SQ4BitGemmM1Kernel_CompFp32_Impl( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; - - nblk -= NCols; - } - - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } -} - -// Block dequantize a 16 x NCols section of B from column major source to row major destination. -template -MLAS_FORCEINLINE void -Q4BitBlkDequantB_16xNCols( - const std::byte* QuantBDataPtr, - size_t StrideQuantBData, - const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns - [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns - // only used if HasZeroPoint is true - float* DstColPtr -) -{ - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // write, transposed, 16 x NCols values - if constexpr (NCols == 4) { - UnrolledLoop<4>([&](size_t j) { - Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); - - vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); - vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); - vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); - vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); - }); - } else { - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { - DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); - DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); - DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); - DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); - }); - }); - } -} - -template -void -Q4BitBlkDequantBForSgemm_CompFp32_Impl( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - constexpr size_t BlkBitWidth = 4; - - float* Dst = FpData; - - const std::byte* QuantBDataCol = QuantBData; - const float* QuantBScaleCol = QuantBScale; - [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true - - const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true - MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); - - // - // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. - // - - // scales of blocks from 16 adjacent columns - float scale[16]; - // float conversion offsets (including zero point) of blocks from 16 adjacent columns - [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true - - size_t n_cols_remaining = CountN; - while (n_cols_remaining > 15) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < 16; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - constexpr size_t NCols = 4; - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < 16; nn += NCols) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += NCols; - if constexpr (HasZeroPoint) { - OffsetPtr += NCols; - } - DstColPtr += NCols; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - - n_cols_remaining -= 16; - - QuantBDataCol += 16 * StrideQuantBData; - QuantBScaleCol += 16 * BlockStrideQuantB; - if constexpr (HasZeroPoint) { - QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; - } - } - - if (n_cols_remaining > 0) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - // zero out the 16x16 block in Dst first to ensure zero padding - const float32x4_t zero_v = vdupq_n_f32(0.0f); - UnrolledLoop<16 * 4>([&](size_t i) { - vst1q_f32(Dst + 4 * i, zero_v); - }); - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += 1; - if constexpr (HasZeroPoint) { - OffsetPtr += 1; - } - DstColPtr += 1; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - } -} - -void -Q4BitBlkDequantBForSgemm_CompFp32( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - if (QuantBZeroPoint != nullptr) { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } else { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } -} - -// -// CompInt8 kernel implementation. -// - -template -MLAS_FORCEINLINE void -QuantizeBlock( - size_t BlkLen, - const float* A, - size_t ElementCount, - std::byte* QuantA -) -{ - static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); - - assert(BlkLen % SubBlkLen == 0); - - // - // Scan block values first to determine scale. - // - - float amax = 0.0f; // max of absolute values of A block - - size_t k; - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - float32x4_t abs_a[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - abs_a[i] = vabsq_f32(a[i]); - }); - - // find amax of SubBlkLen elements - for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { - for (size_t i = 0; i < interval; ++i) { - abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); - } - } - - // update existing amax - amax = std::max(amax, vmaxvq_f32(abs_a[0])); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - - Q8BlkScale(QuantA) = scale; - - // - // Compute quantized block values. - // - - int8_t* QuantAData = Q8BlkData(QuantA); - - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - UnrolledLoop([&](size_t i) { - a[i] = vmulq_n_f32(a[i], scale_reciprocal); - }); - - int32x4_t a_s32[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - a_s32[i] = vcvtaq_s32_f32(a[i]); - }); - - UnrolledLoop([&](size_t i) { - QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); - QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); - QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); - QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); - }); - } - - // - // Zero out any remaining sub-block elements. - // - - for (; k < BlkLen; k += SubBlkLen) { - const int8x16_t Zeros = vdupq_n_s8(0); - UnrolledLoop([&](size_t i) { - vst1q_s8(QuantAData + k + i * 16, Zeros); - }); - } -} - -void -QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -) -{ - const float* ADataBlkPtr = A; - std::byte* QuantABlkPtr = QuantA; - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); - - ADataBlkPtr += BlkLen; - QuantABlkPtr += Q8BlkSize(BlkLen); - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 16; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); - - // load B - const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); - const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); - int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - bv1 = vsubq_s8(bv1, bzp1); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 8 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - - // load B - const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); - const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); - const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - bv_lo1 = vsubq_s8(bv_lo1, bzp1); - bv_hi1 = vsubq_s8(bv_hi1, bzp1); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - - // quantized dot product - int32x4_t dot0{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen > 32); - assert(BlkLen % 32 == 0); - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - // process blocks in 32-element sub-blocks - const size_t SubBlksPerBlk = BlkLen / 32; - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - // compute combined scale - const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp = [&]() -> int8x16_t { - if constexpr (HasZeroPoint) { - return vdupq_n_s8( - ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) - : std::to_integer((*QuantBZeroPointPtr) >> 4) - ); - } else { - return vdupq_n_s8(8); - } - }(); - - const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); - - for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { - // load A - const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); - const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); - const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); - const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp); - bv1 = vsubq_s8(bv1, bzp); - bv2 = vsubq_s8(bv2, bzp); - bv3 = vsubq_s8(bv3, bzp); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); - dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale); - - // increment block data pointers to next sub-block - QuantADataPtr += 16 * 4; - QuantBDataPtr += 16 * 2; - } - - // increment other block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -void -SQ4BitGemmM1Kernel_CompInt8( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t /*CountK*/, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -} // namespace +} // namespace sqnbitgemm_neon // // Kernel dispatch structure definition. @@ -1485,17 +178,17 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h new file mode 100644 index 0000000000..ef9345d7ac --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + SQNBitGemm ARM NEON kernels. + +--*/ + +#pragma once + +#include + +#include +#include +#include + +#include "mlasi.h" + +namespace sqnbitgemm_neon +{ + +// +// Function declarations for SQNBitGemm ARM NEON kernel entry points. +// Refer to the prototypes in sqnbitgemm.h for documentation. +// These are declared here so they can be used to initialize the +// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// files. +// + +// CompFp32 declarations + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +); + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +); + +// CompInt8 declarations + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +); + +// +// General helpers. +// + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +template +MLAS_FORCEINLINE void +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp new file mode 100644 index 0000000000..ca64ebe3b1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -0,0 +1,646 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_fp32.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +namespace +{ + +// +// CompFp32 kernel implementation. +// + +MLAS_FORCEINLINE void +Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + Transpose4x4(a0, a1, a2, a3); + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +namespace fp32_conversion +{ + +// Manual conversion to float takes place in two steps: +// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. +// This target float range is convenient because the 4-bit source values can be placed directly into the +// target float bits. +// 2. Subtract the conversion offset of 16 from the float result. + +// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. +constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; +// sign|exponent|partial mantissa +// +|131: 2^4|~~~~ <- 4 bits go here + +const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + +constexpr float offset = 16.0f; + +} // namespace fp32_conversion + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; + + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + float32x4_t acc[NCols]{}; + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint is true + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. + // only used if HasZeroPoint is true + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = fp32_conversion::offset + std::to_integer(zp); + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +} // namespace + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } +} + +namespace +{ + +// Block dequantize a 16 x NCols section of B from column major source to row major destination. +template +MLAS_FORCEINLINE void +Q4BitBlkDequantB_16xNCols( + const std::byte* QuantBDataPtr, + size_t StrideQuantBData, + const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns + [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns + // only used if HasZeroPoint is true + float* DstColPtr +) +{ + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // write, transposed, 16 x NCols values + if constexpr (NCols == 4) { + UnrolledLoop<4>([&](size_t j) { + Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); + + vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); + vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); + vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); + vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); + }); + } else { + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { + DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); + DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); + DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); + DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); + }); + }); + } +} + +template +void +Q4BitBlkDequantBForSgemm_CompFp32_Impl( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth = 4; + + float* Dst = FpData; + + const std::byte* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true + MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + // + // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. + // + + // scales of blocks from 16 adjacent columns + float scale[16]; + // float conversion offsets (including zero point) of blocks from 16 adjacent columns + [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true + + size_t n_cols_remaining = CountN; + while (n_cols_remaining > 15) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < 16; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + constexpr size_t NCols = 4; + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < 16; nn += NCols) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += NCols; + if constexpr (HasZeroPoint) { + OffsetPtr += NCols; + } + DstColPtr += NCols; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + + n_cols_remaining -= 16; + + QuantBDataCol += 16 * StrideQuantBData; + QuantBScaleCol += 16 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; + } + } + + if (n_cols_remaining > 0) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + // zero out the 16x16 block in Dst first to ensure zero padding + const float32x4_t zero_v = vdupq_n_f32(0.0f); + UnrolledLoop<16 * 4>([&](size_t i) { + vst1q_f32(Dst + 4 * i, zero_v); + }); + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += 1; + if constexpr (HasZeroPoint) { + OffsetPtr += 1; + } + DstColPtr += 1; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + } +} + +} // namespace + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + if (QuantBZeroPoint != nullptr) { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } else { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp new file mode 100644 index 0000000000..db3b9ee656 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -0,0 +1,1315 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" +#include "sqnbitgemm_q8_block.h" + +namespace sqnbitgemm_neon +{ + +// +// CompInt8 kernel implementation. +// + +namespace +{ + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. + // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +} // namespace + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +namespace +{ + +// +// The ComputeRxC functions compute an R row by C column tile of the output matrix. +// + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + + // TODO handling only 16 elements per accumulator at a time here, probably can do better + { + // load A + const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); + + // load B + const uint8x8_t bv_packed_col0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x8_t bv_packed_col1 = vld1_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + int8x16_t bv_col0 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col0, LowMaskU8x8), + vshr_n_u8(bv_packed_col0, 4) + ) + ); + int8x16_t bv_col1 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col1, LowMaskU8x8), + vshr_n_u8(bv_packed_col1, 4) + ) + ); + + // subtract B zero point + bv_col0 = vsubq_s8(bv_col0, vdupq_n_s8(bzp_col0)); + bv_col1 = vsubq_s8(bv_col1, vdupq_n_s8(bzp_col1)); + + // quantized dot product + int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; + dot00 = vdotq_s32(dot00, av_row0, bv_col0); + dot01 = vdotq_s32(dot01, av_row0, bv_col1); + dot10 = vdotq_s32(dot10, av_row1, bv_col0); + dot11 = vdotq_s32(dot11, av_row1, bv_col1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBDataPtr += 8; + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[0] = vaddvq_f32(acc00); + SumPtr[1] = vaddvq_f32(acc01); + SumPtr[ldc + 0] = vaddvq_f32(acc10); + SumPtr[ldc + 1] = vaddvq_f32(acc11); + + if (BiasPtr != nullptr) { + SumPtr[0] += BiasPtr[0]; + SumPtr[1] += BiasPtr[1]; + SumPtr[ldc + 0] += BiasPtr[0]; + SumPtr[ldc + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) { + // load A + const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); + const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); + const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); + + // load B + const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed_col1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16)); + int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4)); + int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16)); + int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4)); + + // subtract B zero point + bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0)); + bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0)); + bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1)); + bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1)); + + // quantized dot product + int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; + dot00 = vdotq_s32(vdotq_s32(dot00, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); + dot01 = vdotq_s32(vdotq_s32(dot01, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); + dot10 = vdotq_s32(vdotq_s32(dot10, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); + dot11 = vdotq_s32(vdotq_s32(dot11, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + + // increment block data pointers to next sub-block + QuantADataPtrRow0 += 32; + QuantADataPtrRow1 += 32; + QuantBDataPtr += 16; + } + + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[0] = vaddvq_f32(acc00); + SumPtr[1] = vaddvq_f32(acc01); + SumPtr[ldc + 0] = vaddvq_f32(acc10); + SumPtr[ldc + 1] = vaddvq_f32(acc11); + + if (BiasPtr != nullptr) { + SumPtr[0] += BiasPtr[0]; + SumPtr[1] += BiasPtr[1]; + SumPtr[ldc + 0] += BiasPtr[0]; + SumPtr[ldc + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 32; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + int32x4_t dot0{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); + dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +AdvanceColPtrs( + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const std::byte*& QuantBDataColPtr, + const float*& QuantBScaleColPtr, + const std::byte*& QuantBZeroPointColPtr, + const float*& BiasPtr, + float*& SumPtr +) +{ + QuantBDataColPtr += NumCols * StrideQuantBData; + QuantBScaleColPtr += NumCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NumCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NumCols : 0; + SumPtr += NumCols; +} + +template +MLAS_FORCEINLINE void +AdvanceRowPtrs( + size_t StrideQuantA, + size_t ldc, + const std::byte*& QuantARowPtr, + float*& SumRowPtr +) +{ + QuantARowPtr += NumRows * StrideQuantA; + SumRowPtr += NumRows * ldc; +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmKernel_CompInt8_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmKernel_CompInt8_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } +} + +} // namespace + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 71a6123b86..f391027de4 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -419,9 +419,10 @@ static size_t SQNBitGemmRegisterAllShortExecuteTests() { return count; } -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (is_short_execute) { - return SQNBitGemmRegisterAllShortExecuteTests() > 0; - } - return false; -}); +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests(); + } + return 0; + });