add apis to neon and other avxs

Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
This commit is contained in:
Liqun Fu 2025-02-03 12:24:40 -08:00
parent f6f22e30d5
commit 3e1a951448
6 changed files with 92 additions and 2 deletions

View file

@ -167,6 +167,61 @@ Q4BitGemmPerGemmWorkspaceAlignment(
}
}
size_t
Q2BitGemmPackQuantBDataSize(
size_t /*N*/,
size_t /*K*/,
size_t /*BlkLen*/,
MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/
)
{
return 0;
}
void
SQ2BitGemmPackQuantBData(
size_t /*N*/,
size_t /*K*/,
size_t /*BlkLen*/,
MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/,
const std::byte* /*QuantBDataBegin*/,
std::byte* /*PackedQuantBDataBegin*/,
MLAS_THREADPOOL* /*ThreadPool*/
)
{
}
size_t
Q2BitGemmPerGemmWorkspaceSize(
size_t /*M*/,
size_t /*N*/,
size_t /*K*/,
size_t /*BlkLen*/,
MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/
)
{
return 0;
}
size_t
SQ2BitGemmKernel_CompInt8_avx2(
size_t /*BlkLen*/,
const std::byte* /*QuantA*/,
const std::byte* /*QuantBData*/,
const float* /*QuantBScale*/,
const std::byte* /*QuantBZeroPoint*/,
float* /*C*/,
size_t /*CountM*/,
size_t /*CountN*/,
size_t /*CountK*/,
size_t /*BlockCountK*/,
size_t /*ldc*/,
const float* /*Bias*/
)
{
return 0;
}
} // namespace
} // namespace sqnbitgemm_neon
@ -197,5 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
return d;
}();

View file

@ -90,6 +90,7 @@ SQ2BitGemmKernel_CompInt8_avx2(
const float* /*Bias*/
)
{
// reference SQ4BitGemmKernel_CompInt8_avx2
return 0;
}
@ -101,4 +102,5 @@ QuantizeARow_CompInt8(
std::byte* /*QuantA*/
)
{
// shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code.
}

View file

@ -1375,5 +1375,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() {
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2;
// change funcions if implementation are different from avx2
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
return d;
}();

View file

@ -32,6 +32,7 @@ Abstract:
//
#include "sqnbitgemm_kernel_avx_common_fp32.h"
#include "sqnbitgemm_bitnet_kernel_avx2.h"
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32_avx512(
@ -368,5 +369,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
// change funcions if implementation are different from avx2
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
return d;
}();

View file

@ -27,6 +27,7 @@ Abstract:
#include "sqnbitgemm_kernel_avx512_int8_blklen32.h"
#include "sqnbitgemm_kernel_avx512_int8_blklen64.h"
#include "sqnbitgemm_kernel_avx512_int8_blklen128.h"
#include "sqnbitgemm_bitnet_kernel_avx2.h"
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32(
@ -353,5 +354,13 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
// change funcions if implementation are different from avx2
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
return d;
}();

View file

@ -147,8 +147,6 @@ class MlasSQNBitGemmTest : public MlasTestBase {
b_zp = 8;
} else if constexpr (BlkBitWidth == 2) {
b_zp = 2;
} else {
static_assert(false, "only implemented for 2- and 4-bit quantized B");
}
int pack_size = 8 / BlkBitWidth;