mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[MLAS] AArch64 SQNBitGemm CompInt8 initial multi-row implementation (#21193)
Update AArch64 SQNBitGemm CompInt8 kernels to process matrix in tiles. E.g., computing the output in 2x2 tiles allows us to compute four elements of the output with one read of two rows of A and two columns of B. Also moved some code around as it was getting big for a single file.
This commit is contained in:
parent
8749fa381e
commit
20cd3394fc
12 changed files with 2248 additions and 1384 deletions
|
|
@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows)
|
|||
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
)
|
||||
|
||||
set(mlas_platform_preprocess_srcs
|
||||
|
|
@ -350,9 +353,12 @@ else()
|
|||
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
|
||||
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
|
||||
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
|
||||
if (NOT APPLE)
|
||||
set(mlas_platform_srcs
|
||||
|
|
|
|||
|
|
@ -16,10 +16,11 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
|
|
@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable(
|
|||
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
|
||||
}
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
|
||||
return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr &&
|
||||
Dispatch->QuantizeARow_CompInt8 != nullptr;
|
||||
}
|
||||
default: {
|
||||
|
|
@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32(
|
|||
if (bias) {
|
||||
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
|
||||
}
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
|
||||
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
|
||||
RowsHandled, CountN, ldc
|
||||
);
|
||||
}
|
||||
|
||||
c_blk += ldc * RowsHandled;
|
||||
a_row += lda * RowsHandled;
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
}
|
||||
}
|
||||
|
|
@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8(
|
|||
|
||||
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
||||
|
||||
if (RangeCountM == 1) {
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, size_t{128});
|
||||
|
||||
const std::byte* a_row = QuantA;
|
||||
const std::byte* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const std::byte* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
|
||||
BlkLen,
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
||||
);
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM, RangeStartN + n,
|
||||
RangeCountM, CountN, ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
|
||||
// TODO Replace it with an optimized implementation.
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, size_t{128});
|
||||
|
|
@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8(
|
|||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
for (size_t m = 0; m < RangeCountM; ++m) {
|
||||
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
|
||||
size_t RowsRemaining = RangeCountM;
|
||||
while (RowsRemaining > 0) {
|
||||
const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
|
||||
BlkLen,
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias
|
||||
);
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM, RangeStartN + n,
|
||||
RangeCountM, CountN, ldc
|
||||
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
|
||||
RowsHandled, CountN, ldc
|
||||
);
|
||||
}
|
||||
|
||||
c_blk += ldc;
|
||||
a_row += lda;
|
||||
c_blk += RowsHandled * ldc;
|
||||
a_row += RowsHandled * lda;
|
||||
|
||||
RowsRemaining -= RowsHandled;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
|
|||
/**
|
||||
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
|
||||
* A and B are block quantized and B is column major.
|
||||
* This kernel handles the special case where M, the number of rows of A and C, is 1.
|
||||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param QuantA Supplies the quantized A matrix.
|
||||
|
|
@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
|
|||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param[out] C Supplies the output C matrix.
|
||||
* @param CountN Number of columns of B and C.
|
||||
* @param CountM Number of rows of A and C to process, an upper bound.
|
||||
* @param CountN Number of columns of B and C to process.
|
||||
* @param CountK Number of columns of A and rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
* @param BlockCountK Number of blocks in one row of A and one column of B.
|
||||
* @param ldc Number of elements between adjacent rows of C.
|
||||
* @param Bias Bias vector of length N.
|
||||
*
|
||||
* @return The number of rows of A and C that were processed, at most CountM.
|
||||
*/
|
||||
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
|
||||
typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
size_t BlockCountK,
|
||||
size_t ldc,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
|
||||
SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
|
||||
|
|
|
|||
|
|
@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
|
|||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
SQ4BitGemmKernel_CompInt8_avx2(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK,
|
||||
size_t ldc,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(ldc);
|
||||
|
||||
if (CountM == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
SQ4BitGemmM1Kernel_CompInt8_avx2(
|
||||
BlkLen,
|
||||
QuantA,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
C,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK,
|
||||
Bias
|
||||
);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <size_t NCols, bool HasZeroPoint>
|
||||
MLAS_FORCEINLINE void
|
||||
ComputeDotProducts_BlkLen16_CompFp32_avx2(
|
||||
|
|
@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
|
|||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
|
||||
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2;
|
||||
|
||||
return d;
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
|
|||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
|
||||
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512;
|
||||
|
||||
return d;
|
||||
|
|
|
|||
|
|
@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
|
|||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
SQ4BitGemmKernel_CompInt8_avx512vnni(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK,
|
||||
size_t ldc,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(ldc);
|
||||
|
||||
if (CountM == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
|
||||
BlkLen,
|
||||
QuantA,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
C,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK,
|
||||
Bias
|
||||
);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
void MLASCALL
|
||||
MlasQ80BlkQuantRow_avx512(
|
||||
size_t BlkLen,
|
||||
|
|
@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
|
|||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni;
|
||||
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni;
|
||||
d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512;
|
||||
|
||||
return d;
|
||||
|
|
|
|||
|
|
@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2(
|
|||
const size_t BlockStrideQuantB
|
||||
);
|
||||
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompInt8_avx2(
|
||||
size_t
|
||||
SQ4BitGemmKernel_CompInt8_avx2(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
size_t BlockCountK,
|
||||
size_t ldc,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
144
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h
Normal file
144
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
sqnbitgemm_kernel_neon.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes function declarations and common helper functions for
|
||||
SQNBitGemm ARM NEON kernels.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
namespace sqnbitgemm_neon
|
||||
{
|
||||
|
||||
//
|
||||
// Function declarations for SQNBitGemm ARM NEON kernel entry points.
|
||||
// Refer to the prototypes in sqnbitgemm.h for documentation.
|
||||
// These are declared here so they can be used to initialize the
|
||||
// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate
|
||||
// files.
|
||||
//
|
||||
|
||||
// CompFp32 declarations
|
||||
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK
|
||||
);
|
||||
|
||||
// CompInt8 declarations
|
||||
|
||||
void
|
||||
QuantizeARow_CompInt8(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
size_t CountK,
|
||||
std::byte* QuantA
|
||||
);
|
||||
|
||||
size_t
|
||||
SQ4BitGemmKernel_CompInt8(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountM,
|
||||
size_t CountN,
|
||||
size_t /*CountK*/,
|
||||
size_t BlockCountK,
|
||||
size_t ldc,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
//
|
||||
// General helpers.
|
||||
//
|
||||
|
||||
template <typename IterationFn, size_t... Indices>
|
||||
MLAS_FORCEINLINE void
|
||||
UnrolledLoopIterations(IterationFn&& f, std::index_sequence<Indices...> /* indices */)
|
||||
{
|
||||
(f(Indices), ...);
|
||||
}
|
||||
|
||||
template <size_t N, typename IterationFn>
|
||||
MLAS_FORCEINLINE void
|
||||
UnrolledLoop(IterationFn&& f)
|
||||
{
|
||||
UnrolledLoopIterations(std::forward<IterationFn>(f), std::make_index_sequence<N>());
|
||||
}
|
||||
|
||||
template <size_t Capacity>
|
||||
MLAS_FORCEINLINE void
|
||||
LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
|
||||
{
|
||||
static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4.");
|
||||
|
||||
assert(count <= Capacity);
|
||||
|
||||
size_t vi = 0; // vector index
|
||||
|
||||
// handle 4 values at a time
|
||||
while (count > 3) {
|
||||
dst[vi] = vld1q_f32(src);
|
||||
|
||||
vi += 1;
|
||||
src += 4;
|
||||
count -= 4;
|
||||
}
|
||||
|
||||
// handle remaining values
|
||||
if (count > 0) {
|
||||
dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0);
|
||||
|
||||
if (count > 1) {
|
||||
dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1);
|
||||
|
||||
if (count > 2) {
|
||||
dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sqnbitgemm_neon
|
||||
646
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp
Normal file
646
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp
Normal file
|
|
@ -0,0 +1,646 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
sqnbitgemm_kernel_neon_fp32.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This module implements the float/quantized n-bit integer matrix
|
||||
multiplication kernels for ARM NEON specific to
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32.
|
||||
|
||||
--*/
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_kernel_neon.h"
|
||||
|
||||
namespace sqnbitgemm_neon
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
//
|
||||
// CompFp32 kernel implementation.
|
||||
//
|
||||
|
||||
MLAS_FORCEINLINE void
|
||||
Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3)
|
||||
{
|
||||
// aN: aN_0 aN_1 aN_2 aN_3
|
||||
|
||||
float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1
|
||||
float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3
|
||||
float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1
|
||||
float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3
|
||||
|
||||
// a0_0 a1_0 a2_0 a3_0
|
||||
a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
|
||||
// a0_1 a1_1 a2_1 a3_1
|
||||
a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
|
||||
// a0_2 a1_2 a3_2 a3_2
|
||||
a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
|
||||
// a0_3 a1_3 a2_3 a3_3
|
||||
a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE float32x4_t
|
||||
FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
|
||||
{
|
||||
Transpose4x4(a0, a1, a2, a3);
|
||||
return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3));
|
||||
}
|
||||
|
||||
namespace fp32_conversion
|
||||
{
|
||||
|
||||
// Manual conversion to float takes place in two steps:
|
||||
// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f].
|
||||
// This target float range is convenient because the 4-bit source values can be placed directly into the
|
||||
// target float bits.
|
||||
// 2. Subtract the conversion offset of 16 from the float result.
|
||||
|
||||
// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values.
|
||||
constexpr uint16_t float_high_half_template = 0b0'10000011'0000000;
|
||||
// sign|exponent|partial mantissa
|
||||
// +|131: 2^4|~~~~ <- 4 bits go here
|
||||
|
||||
const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template);
|
||||
|
||||
constexpr float offset = 16.0f;
|
||||
|
||||
} // namespace fp32_conversion
|
||||
|
||||
template <size_t NCols, bool HasZeroPoint>
|
||||
MLAS_FORCEINLINE void
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* ARowPtr,
|
||||
const std::byte* QuantBDataColPtr,
|
||||
const float* QuantBScaleColPtr,
|
||||
const std::byte* QuantBZeroPointColPtr,
|
||||
float* SumPtr,
|
||||
size_t CountK,
|
||||
size_t StrideQuantBData,
|
||||
size_t StrideQuantBScale,
|
||||
size_t StrideQuantBZeroPoint,
|
||||
const float* BiasPtr
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
constexpr size_t SubBlkLen = 16;
|
||||
|
||||
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
|
||||
|
||||
assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
|
||||
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
float32x4_t acc[NCols]{};
|
||||
|
||||
const std::byte* QuantBData = QuantBDataColPtr;
|
||||
const float* QuantBScale = QuantBScaleColPtr;
|
||||
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
|
||||
// only used if HasZeroPoint is true
|
||||
|
||||
for (size_t k = 0; k < CountK; k += BlkLen) {
|
||||
const size_t k_blk_len = std::min(CountK - k, BlkLen);
|
||||
|
||||
float scale[NCols];
|
||||
UnrolledLoop<NCols>(
|
||||
[&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
|
||||
);
|
||||
|
||||
[[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset.
|
||||
// only used if HasZeroPoint is true
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
|
||||
const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[i] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
});
|
||||
}
|
||||
|
||||
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
|
||||
// load A row vector elements
|
||||
|
||||
// load `SubBlkLen` elements from A, padded with 0's if there aren't enough
|
||||
const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen);
|
||||
float32x4_t av[4]{};
|
||||
LoadFloatData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
|
||||
|
||||
// load B column vectors
|
||||
uint8x8_t bv_packed[NCols];
|
||||
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_packed[i] = vld1_u8(
|
||||
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
|
||||
);
|
||||
});
|
||||
|
||||
uint8x8_t bv_u8[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
|
||||
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
|
||||
});
|
||||
|
||||
// shift left 3 and widen to 16 bits
|
||||
uint16x8_t bv_u16[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 3;
|
||||
bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift);
|
||||
bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift);
|
||||
});
|
||||
|
||||
// combine 4 bits with float high half template
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v);
|
||||
bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v);
|
||||
});
|
||||
|
||||
// `SubBlkLen` floats of B
|
||||
float32x4_t bv[NCols][4];
|
||||
|
||||
// shift left 16, widen to 32 bits, and reinterpret as float
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 16;
|
||||
bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift));
|
||||
bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift));
|
||||
|
||||
bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift));
|
||||
bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift));
|
||||
});
|
||||
|
||||
// subtract float conversion offset and zero point
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
} else {
|
||||
const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f);
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
}
|
||||
|
||||
// multiply by scale
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t scale_v = vdupq_n_f32(scale[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); });
|
||||
});
|
||||
|
||||
// c[m,n] += a[m,k] * b[k,n]
|
||||
UnrolledLoop<4>([&](size_t j) {
|
||||
UnrolledLoop<NCols>([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); });
|
||||
});
|
||||
}
|
||||
|
||||
// increment pointers to next block
|
||||
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
QuantBScale += 1;
|
||||
if constexpr (HasZeroPoint) {
|
||||
QuantBZeroPointIdx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (NCols == 4) {
|
||||
float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
|
||||
|
||||
if (BiasPtr != nullptr) {
|
||||
sum = vaddq_f32(sum, vld1q_f32(BiasPtr));
|
||||
}
|
||||
|
||||
vst1q_f32(SumPtr, sum);
|
||||
} else {
|
||||
for (size_t i = 0; i < NCols; ++i) {
|
||||
SumPtr[i] = vaddvq_f32(acc[i]);
|
||||
if (BiasPtr != nullptr) {
|
||||
SumPtr[i] += BiasPtr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasZeroPoint>
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompFp32_Impl(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
constexpr size_t NCols = 4;
|
||||
|
||||
const float* ARowPtr = A;
|
||||
float* CRowPtr = C;
|
||||
|
||||
const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t StrideQuantBScale = BlockCountK;
|
||||
const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
|
||||
|
||||
const float* BiasPtr = Bias;
|
||||
|
||||
const std::byte* QuantBDataColPtr = QuantBData;
|
||||
const float* QuantBScaleColPtr = QuantBScale;
|
||||
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
|
||||
|
||||
float* SumPtr = CRowPtr;
|
||||
|
||||
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
|
||||
|
||||
while (nblk >= 0) {
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols, HasZeroPoint>(
|
||||
BlkLen,
|
||||
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
);
|
||||
|
||||
// move to next `NCols` columns
|
||||
|
||||
QuantBDataColPtr += NCols * StrideQuantBData;
|
||||
QuantBScaleColPtr += NCols * StrideQuantBScale;
|
||||
if constexpr (HasZeroPoint) {
|
||||
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
|
||||
}
|
||||
|
||||
BiasPtr += BiasPtr != nullptr ? NCols : 0;
|
||||
SumPtr += NCols;
|
||||
|
||||
nblk -= NCols;
|
||||
}
|
||||
|
||||
// left over columns less than `NCols`?
|
||||
nblk += NCols;
|
||||
for (int64_t n = 0; n < nblk; ++n) {
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>(
|
||||
BlkLen,
|
||||
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
);
|
||||
|
||||
// move to next column
|
||||
|
||||
QuantBDataColPtr += StrideQuantBData;
|
||||
QuantBScaleColPtr += StrideQuantBScale;
|
||||
if constexpr (HasZeroPoint) {
|
||||
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
|
||||
}
|
||||
|
||||
BiasPtr += BiasPtr != nullptr ? 1 : 0;
|
||||
SumPtr += 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
if (QuantBZeroPoint != nullptr) {
|
||||
constexpr bool HasZeroPoint = true;
|
||||
SQ4BitGemmM1Kernel_CompFp32_Impl<HasZeroPoint>(
|
||||
BlkLen,
|
||||
A,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
C,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK,
|
||||
Bias
|
||||
);
|
||||
} else {
|
||||
constexpr bool HasZeroPoint = false;
|
||||
SQ4BitGemmM1Kernel_CompFp32_Impl<HasZeroPoint>(
|
||||
BlkLen,
|
||||
A,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
C,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK,
|
||||
Bias
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// Block dequantize a 16 x NCols section of B from column major source to row major destination.
|
||||
template <size_t NCols, bool HasZeroPoint>
|
||||
MLAS_FORCEINLINE void
|
||||
Q4BitBlkDequantB_16xNCols(
|
||||
const std::byte* QuantBDataPtr,
|
||||
size_t StrideQuantBData,
|
||||
const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns
|
||||
[[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns
|
||||
// only used if HasZeroPoint is true
|
||||
float* DstColPtr
|
||||
)
|
||||
{
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
// load B column vectors
|
||||
uint8x8_t bv_packed[NCols];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_packed[i] = vld1_u8(
|
||||
reinterpret_cast<const uint8_t*>(QuantBDataPtr) + i * StrideQuantBData
|
||||
);
|
||||
});
|
||||
|
||||
uint8x8_t bv_u8[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
|
||||
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
|
||||
});
|
||||
|
||||
// shift left 3 and widen to 16 bits
|
||||
uint16x8_t bv_u16[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 3;
|
||||
bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift);
|
||||
bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift);
|
||||
});
|
||||
|
||||
// combine 4 bits with float high half template
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v);
|
||||
bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v);
|
||||
});
|
||||
|
||||
// `SubBlkLen` floats of B
|
||||
float32x4_t bv[NCols][4];
|
||||
|
||||
// shift left 16, widen to 32 bits, and reinterpret as float
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 16;
|
||||
bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift));
|
||||
bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift));
|
||||
|
||||
bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift));
|
||||
bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift));
|
||||
});
|
||||
|
||||
// subtract float conversion offset and zero point
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
} else {
|
||||
const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f);
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
}
|
||||
|
||||
// multiply by scale
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); });
|
||||
});
|
||||
|
||||
// write, transposed, 16 x NCols values
|
||||
if constexpr (NCols == 4) {
|
||||
UnrolledLoop<4>([&](size_t j) {
|
||||
Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]);
|
||||
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]);
|
||||
});
|
||||
} else {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) {
|
||||
DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0);
|
||||
DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1);
|
||||
DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2);
|
||||
DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasZeroPoint>
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
float* Dst = FpData;
|
||||
|
||||
const std::byte* QuantBDataCol = QuantBData;
|
||||
const float* QuantBScaleCol = QuantBScale;
|
||||
[[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true
|
||||
|
||||
const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
[[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true
|
||||
MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
|
||||
|
||||
//
|
||||
// Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time.
|
||||
//
|
||||
|
||||
// scales of blocks from 16 adjacent columns
|
||||
float scale[16];
|
||||
// float conversion offsets (including zero point) of blocks from 16 adjacent columns
|
||||
[[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true
|
||||
|
||||
size_t n_cols_remaining = CountN;
|
||||
while (n_cols_remaining > 15) {
|
||||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) {
|
||||
for (size_t nn = 0; nn < 16; ++nn) {
|
||||
scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx];
|
||||
|
||||
if constexpr (HasZeroPoint) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2];
|
||||
const std::byte zp = ((k_blk_idx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[nn] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
for (size_t kk = 0; kk < kklen; kk += 16) {
|
||||
constexpr size_t NCols = 4;
|
||||
|
||||
const float* ScalePtr = &scale[0];
|
||||
const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr;
|
||||
|
||||
float* DstColPtr = Dst;
|
||||
|
||||
for (size_t nn = 0; nn < 16; nn += NCols) {
|
||||
const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8;
|
||||
|
||||
Q4BitBlkDequantB_16xNCols<NCols, HasZeroPoint>(
|
||||
QuantBDataPtr,
|
||||
StrideQuantBData,
|
||||
ScalePtr,
|
||||
OffsetPtr,
|
||||
DstColPtr
|
||||
);
|
||||
|
||||
ScalePtr += NCols;
|
||||
if constexpr (HasZeroPoint) {
|
||||
OffsetPtr += NCols;
|
||||
}
|
||||
DstColPtr += NCols;
|
||||
}
|
||||
|
||||
Dst += 16 * std::min(kklen - kk, size_t{16});
|
||||
}
|
||||
}
|
||||
|
||||
n_cols_remaining -= 16;
|
||||
|
||||
QuantBDataCol += 16 * StrideQuantBData;
|
||||
QuantBScaleCol += 16 * BlockCountK;
|
||||
if constexpr (HasZeroPoint) {
|
||||
QuantBZeroPointCol += 16 * StrideQuantBZeroPoint;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_cols_remaining > 0) {
|
||||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) {
|
||||
for (size_t nn = 0; nn < n_cols_remaining; ++nn) {
|
||||
scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx];
|
||||
|
||||
if constexpr (HasZeroPoint) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2];
|
||||
const std::byte zp = ((k_blk_idx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[nn] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
for (size_t kk = 0; kk < kklen; kk += 16) {
|
||||
// zero out the 16x16 block in Dst first to ensure zero padding
|
||||
const float32x4_t zero_v = vdupq_n_f32(0.0f);
|
||||
UnrolledLoop<16 * 4>([&](size_t i) {
|
||||
vst1q_f32(Dst + 4 * i, zero_v);
|
||||
});
|
||||
|
||||
const float* ScalePtr = &scale[0];
|
||||
const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr;
|
||||
|
||||
float* DstColPtr = Dst;
|
||||
|
||||
for (size_t nn = 0; nn < n_cols_remaining; ++nn) {
|
||||
const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8;
|
||||
|
||||
Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>(
|
||||
QuantBDataPtr,
|
||||
StrideQuantBData,
|
||||
ScalePtr,
|
||||
OffsetPtr,
|
||||
DstColPtr
|
||||
);
|
||||
|
||||
ScalePtr += 1;
|
||||
if constexpr (HasZeroPoint) {
|
||||
OffsetPtr += 1;
|
||||
}
|
||||
DstColPtr += 1;
|
||||
}
|
||||
|
||||
Dst += 16 * std::min(kklen - kk, size_t{16});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockCountK
|
||||
)
|
||||
{
|
||||
if (QuantBZeroPoint != nullptr) {
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl<true>(
|
||||
BlkLen,
|
||||
FpData,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK
|
||||
);
|
||||
} else {
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl<false>(
|
||||
BlkLen,
|
||||
FpData,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockCountK
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sqnbitgemm_neon
|
||||
1315
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
Normal file
1315
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -419,9 +419,10 @@ static size_t SQNBitGemmRegisterAllShortExecuteTests() {
|
|||
return count;
|
||||
}
|
||||
|
||||
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
|
||||
if (is_short_execute) {
|
||||
return SQNBitGemmRegisterAllShortExecuteTests() > 0;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
static UNUSED_VARIABLE bool added_to_main = AddTestRegister(
|
||||
[](bool is_short_execute) -> size_t {
|
||||
if (is_short_execute) {
|
||||
return SQNBitGemmRegisterAllShortExecuteTests();
|
||||
}
|
||||
return 0;
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue