diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7744e2f634..1d578d09b1 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -89,6 +89,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -384,6 +385,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -394,6 +396,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 30b66cdb2e..f4c49905eb 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,15 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +template +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) { + return vreinterpret_f16_u16( + vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane) + ); +} + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) @@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); } +template +MLAS_FORCEINLINE +void +MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane); +} + MLAS_FORCEINLINE void MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000000..d7d99899d5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -0,0 +1,898 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hqnbitgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_QNBIT_GEMM_COMPUTE_TYPE HQNBIT_CompFp16. + +--*/ + +#include + +#include +#include +#include + +#include "fp16_common.h" +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ +MLAS_FORCEINLINE void +Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, + uint8x8_t& v4, uint8x8_t& v5, uint8x8_t& v6, uint8x8_t& v7) +{ + // v0: | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // v1: | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // v2: | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // v3: | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // v4: | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // v5: | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // v6: | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // v7: | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + + uint8x8x2_t a0 = vtrn_u8(v0, v1); + uint8x8x2_t a1 = vtrn_u8(v2, v3); + uint8x8x2_t a2 = vtrn_u8(v4, v5); + uint8x8x2_t a3 = vtrn_u8(v6, v7); + + // a0[0]: | B00 B10 | B01 B11 | B40 B50 | B41 B51 | B80 B90 | B81 B91 | Bc0 Bd0 | Bc1 Bd1 | + // a0[1]: | B20 B30 | B21 B31 | B60 B70 | B61 B71 | Ba0 Bb0 | Ba1 Bb1 | Be0 Bf0 | Be1 Bf1 | + // a1[0]: | B02 B12 | B03 B13 | B42 B52 | B43 B53 | B82 B92 | B83 B93 | Bc2 Bd2 | Bc3 Bd3 | + // a1[1]: | B22 B32 | B23 B33 | B62 B72 | B63 B73 | Ba2 Bb2 | Ba3 Bb3 | Be2 Bf2 | Be3 Bf3 | + // a2[0]: | B04 B14 | B05 B15 | B44 B54 | B45 B55 | B84 B94 | B85 B95 | Bc4 Bd4 | Bc5 Bd5 | + // a2[1]: | B24 B34 | B25 B35 | B64 B74 | B65 B75 | Ba4 Bb4 | Ba5 Bb5 | Be4 Bf4 | Be5 Bf5 | + // a3[0]: | B06 B16 | B07 B17 | B46 B56 | B47 B57 | B86 B96 | B87 B97 | Bc6 Bd6 | Bc7 Bd7 | + // a3[1]: | B26 B36 | B27 B37 | B66 B76 | B67 B77 | Ba6 Bb6 | Ba7 Bb7 | Be6 Bf6 | Be7 Bf7 | + + uint16x4x2_t b0 = vtrn_u16(vreinterpret_u16_u8(a0.val[0]), vreinterpret_u16_u8(a1.val[0])); + uint16x4x2_t b1 = vtrn_u16(vreinterpret_u16_u8(a0.val[1]), vreinterpret_u16_u8(a1.val[1])); + uint16x4x2_t b2 = vtrn_u16(vreinterpret_u16_u8(a2.val[0]), vreinterpret_u16_u8(a3.val[0])); + uint16x4x2_t b3 = vtrn_u16(vreinterpret_u16_u8(a2.val[1]), vreinterpret_u16_u8(a3.val[1])); + + // b0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B80 B90 | B81 B91 | B82 B92 | B83 B93 | + // b0[1]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | + // b1[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | + // b1[1]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | + // b2[0]: | B04 B14 | B05 B15 | B06 B16 | B07 B17 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // b2[1]: | B44 B54 | B45 B55 | B46 B56 | B47 B57 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // b3[0]: | B24 B34 | B25 B35 | B26 B36 | B27 B37 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // b3[1]: | B64 B74 | B65 B75 | B66 B76 | B67 B77 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), vreinterpret_u32_u16(b2.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), vreinterpret_u32_u16(b2.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b1.val[0]), vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b1.val[1]), vreinterpret_u32_u16(b3.val[1])); + + // c0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // c0[1]: | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // c1[0]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // c1[1]: | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // c2[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // c2[1]: | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // c3[0]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // c3[1]: | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + v0 = vreinterpret_u8_u32(c0.val[0]); + v1 = vreinterpret_u8_u32(c2.val[0]); + v2 = vreinterpret_u8_u32(c1.val[0]); + v3 = vreinterpret_u8_u32(c3.val[0]); + v4 = vreinterpret_u8_u32(c0.val[1]); + v5 = vreinterpret_u8_u32(c2.val[1]); + v6 = vreinterpret_u8_u32(c1.val[1]); + v7 = vreinterpret_u8_u32(c3.val[1]); +} + +MLAS_FORCEINLINE void +Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE void +Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) +{ + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + constexpr size_t nbits = 4; + constexpr size_t k_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % k_blk_dim == 0); + + const size_t k_blk_num = MlasDivRoundup(K, k_blk_dim); + const size_t n_blk_num = MlasDivRoundup(N, n_blk_dim); + constexpr size_t k_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, k_blk_dim); + const size_t iterations = k_blk_num * n_blk_num; // one iteration per block + const size_t ld = MlasDivRoundup(K, BlkLen) * MlasQNBitBlkDataSizeInBytes(nbits, BlkLen); + + // + // For blocks 16_K * 8_N, transpose bytes in 8x8 blocks like this: + // src B_k_n: + // | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + // => dst: + // | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + // + + // + // For blocks < 8_N: + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + MlasTrySimpleParallel( + ThreadPool, iterations, + [&](ptrdiff_t tid) { + const size_t n_blk = tid / k_blk_num; + const size_t k_blk = tid % k_blk_num; + size_t n = n_blk * n_blk_dim; + const size_t src_offset = n * ld + k_blk * k_blk_bytes; + + if (n + n_blk_dim <= N) { + const size_t dst_offset = n * ld + k_blk * k_blk_bytes * n_blk_dim; + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + dst_offset; + + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v1 = vld1_u8(src + ld); + uint8x8_t v2 = vld1_u8(src + 2*ld); + uint8x8_t v3 = vld1_u8(src + 3*ld); + uint8x8_t v4 = vld1_u8(src + 4*ld); + uint8x8_t v5 = vld1_u8(src + 5*ld); + uint8x8_t v6 = vld1_u8(src + 6*ld); + uint8x8_t v7 = vld1_u8(src + 7*ld); + + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + vst1_u8(dst, v0); + vst1_u8(dst + 8, v1); + vst1_u8(dst + 16, v2); + vst1_u8(dst + 24, v3); + vst1_u8(dst + 32, v4); + vst1_u8(dst + 40, v5); + vst1_u8(dst + 48, v6); + vst1_u8(dst + 56, v7); + } else { + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + src_offset; + + for (; n < N; ++n, src += ld, dst += ld) { + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v_even = vand_u8(v0, vdup_n_u8(0x0F)); + uint8x8_t v_odd = vshr_n_u8(v0, 4); + uint8x8x2_t v1 = vzip_u8(v_even, v_odd); + uint8x8_t v2 = vorr_u8(v1.val[0], vshl_n_u8(v1.val[1], 4)); + vst1_u8(dst, v2); + } + } + } + ); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t b01 = vld1_u8(src_ptr); + uint8x8_t b23 = vld1_u8(src_ptr + 8); + uint8x8_t b45 = vld1_u8(src_ptr + 16); + uint8x8_t b67 = vld1_u8(src_ptr + 24); + uint8x8_t b89 = vld1_u8(src_ptr + 32); + uint8x8_t bab = vld1_u8(src_ptr + 40); + uint8x8_t bcd = vld1_u8(src_ptr + 48); + uint8x8_t bef = vld1_u8(src_ptr + 56); + + float16x8_t b0 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b01, low_mask), 0)); + float16x8_t b1 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b01, 4), 0)); + float16x8_t b2 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b23, low_mask), 0)); + float16x8_t b3 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b23, 4), 0)); + float16x8_t b4 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b45, low_mask), 0)); + float16x8_t b5 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b45, 4), 0)); + float16x8_t b6 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b67, low_mask), 0)); + float16x8_t b7 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b67, 4), 0)); + float16x8_t b8 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b89, low_mask), 0)); + float16x8_t b9 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b89, 4), 0)); + float16x8_t ba = vcvtq_f16_u16(vshll_n_u8(vand_u8(bab, low_mask), 0)); + float16x8_t bb = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bab, 4), 0)); + float16x8_t bc = vcvtq_f16_u16(vshll_n_u8(vand_u8(bcd, low_mask), 0)); + float16x8_t bd = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bcd, 4), 0)); + float16x8_t be = vcvtq_f16_u16(vshll_n_u8(vand_u8(bef, low_mask), 0)); + float16x8_t bf = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bef, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, b0, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, b1, scale); + float16x8_t c2 = vfmaq_f16(neg_scaled_zp, b2, scale); + float16x8_t c3 = vfmaq_f16(neg_scaled_zp, b3, scale); + float16x8_t c4 = vfmaq_f16(neg_scaled_zp, b4, scale); + float16x8_t c5 = vfmaq_f16(neg_scaled_zp, b5, scale); + float16x8_t c6 = vfmaq_f16(neg_scaled_zp, b6, scale); + float16x8_t c7 = vfmaq_f16(neg_scaled_zp, b7, scale); + float16x8_t c8 = vfmaq_f16(neg_scaled_zp, b8, scale); + float16x8_t c9 = vfmaq_f16(neg_scaled_zp, b9, scale); + float16x8_t ca = vfmaq_f16(neg_scaled_zp, ba, scale); + float16x8_t cb = vfmaq_f16(neg_scaled_zp, bb, scale); + float16x8_t cc = vfmaq_f16(neg_scaled_zp, bc, scale); + float16x8_t cd = vfmaq_f16(neg_scaled_zp, bd, scale); + float16x8_t ce = vfmaq_f16(neg_scaled_zp, be, scale); + float16x8_t cf = vfmaq_f16(neg_scaled_zp, bf, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); + MlasStoreFloat16x8(dst_ptr + 16, c2); + MlasStoreFloat16x8(dst_ptr + 24, c3); + MlasStoreFloat16x8(dst_ptr + 32, c4); + MlasStoreFloat16x8(dst_ptr + 40, c5); + MlasStoreFloat16x8(dst_ptr + 48, c6); + MlasStoreFloat16x8(dst_ptr + 56, c7); + MlasStoreFloat16x8(dst_ptr + 64, c8); + MlasStoreFloat16x8(dst_ptr + 72, c9); + MlasStoreFloat16x8(dst_ptr + 80, ca); + MlasStoreFloat16x8(dst_ptr + 88, cb); + MlasStoreFloat16x8(dst_ptr + 96, cc); + MlasStoreFloat16x8(dst_ptr + 104, cd); + MlasStoreFloat16x8(dst_ptr + 112, ce); + MlasStoreFloat16x8(dst_ptr + 120, cf); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 1 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t v0 = vld1_u8(src_ptr); + + float16x8_t f_low = vcvtq_f16_u16(vshll_n_u8(vand_u8(v0, low_mask), 0)); + float16x8_t f_high = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(v0, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, f_low, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, f_high, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); +} + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +) { + MLAS_UNREFERENCED_PARAMETER(K); + constexpr size_t nbits = 4; + constexpr size_t kk_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % kk_blk_dim == 0); + + const size_t kk_blk_num = BlockCountK * BlkLen / kk_blk_dim; + constexpr size_t kk_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, kk_blk_dim); + const size_t kk_n_src_bytes = kk_blk_bytes * n_blk_dim; + const size_t kk_n_dst_size = kk_blk_dim * n_blk_dim; + const size_t ld_blk_src = kk_blk_num * kk_n_src_bytes; + const size_t ld_blk_dst = BlkLen * BlockCountK * n_blk_dim; + const size_t ld_blk_scale = BlockCountK * n_blk_dim; + const size_t ld_zp = (BlockCountK + 1) / 2; + const size_t ld_blk_zp = ld_zp * n_blk_dim; + const float16x8_t zp_mid_point_vec = MlasBroadcastFloat16x8(MLAS_FP16(8.0f).val); + const bool has_zp = QuantBZeroPoint != nullptr; + + size_t n = 0; + for (; n + n_blk_dim <= CountN; n += n_blk_dim) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + const std::uint8_t* src_ptr = reinterpret_cast(QuantBData); + auto* dst_ptr = reinterpret_cast<_mlas_fp16_*>(FpData); + + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + // prepare scales and zero_points for the block + _mlas_fp16_ scales[n_blk_dim]; + uint16_t zero_points[n_blk_dim]; + float16x8_t scale_vec; + float16x8_t neg_scaled_zp_vec; + + UnrolledLoop([&](int nn){ + scales[nn] = scales_ptr[nn * BlockCountK]; + }); + scale_vec = MlasLoadFloat16x8(scales); + + if (has_zp) { + UnrolledLoop([&](int nn){ + uint8_t zp = zero_points_ptr[nn * ld_zp]; + zp = (k_blk_i & 1) ? (zp >> 4) : (zp & 0x0F); + zero_points[nn] = static_cast(zp); + }); + uint16x8_t zp_u16_vec = vld1q_u16(zero_points); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<8, 16>(src_ptr, scale_vec, neg_scaled_zp_vec, dst_ptr); + + src_ptr += kk_n_src_bytes; + dst_ptr += kk_n_dst_size; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBData += ld_blk_src; + FpData += ld_blk_dst; + QuantBScale += ld_blk_scale; + QuantBZeroPoint = has_zp ? QuantBZeroPoint + ld_blk_zp : nullptr; + } + + // remaining N + for (; n < CountN; ++n) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + const auto scale = scales_ptr[0]; + float16x8_t scale_vec = MlasBroadcastFloat16x8(scale); + float16x8_t neg_scaled_zp_vec; + + if (has_zp) { + uint8_t zero_point = static_cast(zero_points_ptr[0]); + zero_point = (k_blk_i & 1) ? (zero_point >> 4) : (zero_point & 0x0F); + uint16x8_t zp_u16_vec = vdupq_n_u16(static_cast(zero_point)); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<1, 16>( + reinterpret_cast(QuantBData), scale_vec, neg_scaled_zp_vec, + reinterpret_cast<_mlas_fp16_*>(FpData) + ); + + QuantBData += kk_blk_bytes; + FpData += kk_blk_dim; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBScale += BlockCountK; + if (has_zp) { + QuantBZeroPoint += ld_zp; + } + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8), float16x8_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x8(Bias); + } else { + return MlasZeroFloat16x8(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 4), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x4(Bias); + } else { + return MlasZeroFloat16x4(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N == 2 || N == 1)), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + float16x4_t v = MlasZeroFloat16x4(); + + if (Bias) { + v = MlasLoadLaneFloat16x4<0>(Bias, v); + if constexpr (N == 2) { + v = MlasLoadLaneFloat16x4<1>(Bias + 1, v); + } + return v; + } else { + return v; + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 8), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x8_t a0 = MlasLoadFloat16x8(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + float16x8_t b4 = MlasLoadFloat16x8(B + 32); + float16x8_t b5 = MlasLoadFloat16x8(B + 40); + float16x8_t b6 = MlasLoadFloat16x8(B + 48); + float16x8_t b7 = MlasLoadFloat16x8(B + 56); + + // This version uses less instructions, but introduces dependency path between instructions. + // Must pair it with loop unrolling to alleviate dependency path penalty. + float16x8_t c0 = vfmaq_laneq_f16(accumulator, b0, a0, 0); + c0 = vfmaq_laneq_f16(c0, b1, a0, 1); + c0 = vfmaq_laneq_f16(c0, b2, a0, 2); + c0 = vfmaq_laneq_f16(c0, b3, a0, 3); + c0 = vfmaq_laneq_f16(c0, b4, a0, 4); + c0 = vfmaq_laneq_f16(c0, b5, a0, 5); + c0 = vfmaq_laneq_f16(c0, b6, a0, 6); + c0 = vfmaq_laneq_f16(c0, b7, a0, 7); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 4), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0); + c0 = vfmaq_lane_f16(c0, b1, a0, 1); + c0 = vfmaq_lane_f16(c0, b2, a0, 2); + c0 = vfmaq_lane_f16(c0, b3, a0, 3); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && (K == 2 || K == 1)), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasZeroFloat16x4(); + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + if constexpr (K == 2) a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + float16x8_t b0 = MlasLoadFloat16x8(B), b1; + if constexpr (K == 2) b1 = MlasLoadFloat16x8(B + 8); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0), c01; + if constexpr (K == 2) c01 = vfmaq_lane_f16(c0, b1, a0, 1); + + if constexpr (K == 1) + return c0; + else + return c01; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && K == 8), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x8_t a0 = MlasLoadFloat16x8(A); + + float16x8_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x8(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x8(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x8(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x8(B + ldb * 3); + + float16x8_t c00, c01, c02, c03; + c00 = vmulq_f16(b0, a0); + if constexpr (N > 1) + c01 = vmulq_f16(b1, a0); + else + c01 = MlasZeroFloat16x8(); + if constexpr (N > 2) + c02 = vmulq_f16(b2, a0); + else + c02 = MlasZeroFloat16x8(); + if constexpr (N > 3) + c03 = vmulq_f16(b3, a0); + else + c03 = MlasZeroFloat16x8(); + + Transpose4x8(c00, c01, c02, c03); + + float16x8_t c_low_high = vaddq_f16(vaddq_f16(c00, c01), vaddq_f16(c02, c03)); + float16x4_t c_low = vget_low_f16(c_low_high); + float16x4_t c_high = vget_high_f16(c_low_high); + float16x4_t c = vadd_f16(c_low, c_high); + + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K == 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x4_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x4(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x4(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x4(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x4(B + ldb * 3); + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K > 0 && K < 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasZeroFloat16x4(), b1, b2, b3; + if constexpr (N > 1) b1 = MlasZeroFloat16x4(); + if constexpr (N > 2) b2 = MlasZeroFloat16x4(); + if constexpr (N > 3) b3 = MlasZeroFloat16x4(); + + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + b0 = MlasLoadLaneFloat16x4<0>(B, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<0>(B + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<0>(B + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<0>(B + ldb * 3, b3); + + if constexpr (K >= 2) { + a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + b0 = MlasLoadLaneFloat16x4<1>(B + 1, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 3, b3); + } + + if constexpr (K >= 3) { + a0 = MlasLoadLaneFloat16x4<2>(A + 2, a0); + b0 = MlasLoadLaneFloat16x4<2>(B + 2, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 3, b3); + } + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +typename std::enable_if_t<((CountN >= 1 && CountN <= 16 && ((CountN - 1) & CountN) == 0) && (CountM == 1 || CountM == 2)), void> +HQ4BitGemmKernel_CompFp16_Kernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const _mlas_fp16_* Bias, + _mlas_fp16_* C, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + using RegisterType = typename std::conditional_t<(CountN < 8), float16x4_t, float16x8_t>; + + RegisterType accu00, accu01, accu10, accu11; + constexpr size_t b_step = CountN >= 8 ? 8 : 1; + constexpr size_t N = CountN == 16 ? 8 : CountN; + + if constexpr (CountM == 2) { + accu00 = accu10 = PrepareAccumulator(Bias); + } else { + accu00 = PrepareAccumulator(Bias); + } + if constexpr (CountN == 16) { + if constexpr (CountM == 2) { + accu01 = accu11 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } else { + accu01 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } + } + + size_t k = 0; + for (; k + 8 <= K; k += 8, A += 8, B += b_step * 8) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if (K & 4) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 4, A += 4, B += b_step * 4; + } + + if (K & 2) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 2, A += 2, B += b_step * 2; + } + + if (k < K) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C, accu00); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + 8, accu01); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C, accu00); + } else { + MlasStoreLaneFloat16x4<0>(C, accu00); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + 1, accu00); + } + } + + if constexpr (CountM == 2) { + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C + ldc, accu10); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C + ldc, accu10); + } else { + MlasStoreLaneFloat16x4<0>(C + ldc, accu10); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + ldc + 1, accu10); + } + } + } +} + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + assert(CountM <= 2); + + // 2M_16N is the balance between loop unrolling and register spill. + // More unroll will trigger register spill. + // Less unroll will increase micro kernel dependency path penalty. + // TODO: dequant 16N as continuous segments. Current version dequants 8N. + const auto* a = reinterpret_cast(A); + const auto* b = reinterpret_cast(B); + const auto* bias = reinterpret_cast(Bias); + auto* c = reinterpret_cast<_mlas_fp16_*>(C); + + for (; CountN >= 16; CountN -= 16) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<16, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<16, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 16 * ldb, c += 16; + if (bias) bias += 16; + } + + if (CountN & 8) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<8, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<8, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 8 * ldb, c += 8; + if (bias) bias += 8; + } + + if (CountN & 4) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<4, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<4, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 4 * ldb, c += 4; + if (bias) bias += 4; + } + + if (CountN & 2) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<2, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<2, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 2 * ldb, c += 2; + if (bias) bias += 2; + } + + if (CountN & 1) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<1, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<1, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + } +} +} // namespace sqnbitgemm_neon