diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index d05de64e68..b12e2358d7 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -167,6 +167,61 @@ Q4BitGemmPerGemmWorkspaceAlignment( } } +size_t +Q2BitGemmPackQuantBDataSize( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +void +SQ2BitGemmPackQuantBData( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* /*QuantBDataBegin*/, + std::byte* /*PackedQuantBDataBegin*/, + MLAS_THREADPOOL* /*ThreadPool*/ +) +{ +} + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t /*M*/, + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t /*BlkLen*/, + const std::byte* /*QuantA*/, + const std::byte* /*QuantBData*/, + const float* /*QuantBScale*/, + const std::byte* /*QuantBZeroPoint*/, + float* /*C*/, + size_t /*CountM*/, + size_t /*CountN*/, + size_t /*CountK*/, + size_t /*BlockCountK*/, + size_t /*ldc*/, + const float* /*Bias*/ +) +{ + return 0; +} + } // namespace } // namespace sqnbitgemm_neon @@ -197,5 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 1d7a1ce73e..d6d104967e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -90,6 +90,7 @@ SQ2BitGemmKernel_CompInt8_avx2( const float* /*Bias*/ ) { + // reference SQ4BitGemmKernel_CompInt8_avx2 return 0; } @@ -101,4 +102,5 @@ QuantizeARow_CompInt8( std::byte* /*QuantA*/ ) { + // shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code. } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index fe9720fd7e..56c54cf9be 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1375,5 +1375,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index b4e25d4e40..d07ba72d1e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -32,6 +32,7 @@ Abstract: // #include "sqnbitgemm_kernel_avx_common_fp32.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32_avx512( @@ -368,5 +369,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index a4468bb906..83fba19c17 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -27,6 +27,7 @@ Abstract: #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( @@ -353,5 +354,13 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 26f02466be..d849118aae 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -147,8 +147,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { b_zp = 8; } else if constexpr (BlkBitWidth == 2) { b_zp = 2; - } else { - static_assert(false, "only implemented for 2- and 4-bit quantized B"); } int pack_size = 8 / BlkBitWidth;