mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
add apis to neon and other avxs
Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
This commit is contained in:
parent
f6f22e30d5
commit
3e1a951448
6 changed files with 92 additions and 2 deletions
|
|
@ -167,6 +167,61 @@ Q4BitGemmPerGemmWorkspaceAlignment(
|
|||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
Q2BitGemmPackQuantBDataSize(
|
||||
size_t /*N*/,
|
||||
size_t /*K*/,
|
||||
size_t /*BlkLen*/,
|
||||
MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/
|
||||
)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void
|
||||
SQ2BitGemmPackQuantBData(
|
||||
size_t /*N*/,
|
||||
size_t /*K*/,
|
||||
size_t /*BlkLen*/,
|
||||
MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/,
|
||||
const std::byte* /*QuantBDataBegin*/,
|
||||
std::byte* /*PackedQuantBDataBegin*/,
|
||||
MLAS_THREADPOOL* /*ThreadPool*/
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
size_t
|
||||
Q2BitGemmPerGemmWorkspaceSize(
|
||||
size_t /*M*/,
|
||||
size_t /*N*/,
|
||||
size_t /*K*/,
|
||||
size_t /*BlkLen*/,
|
||||
MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/
|
||||
)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Placeholder for the 2-bit int8 GEMM kernel on this target.
// All inputs are ignored, the output matrix C is left untouched, and the
// function always returns 0.
// NOTE(review): the `_avx2` suffix looks carried over from the AVX2 sources
// even though this definition is closed by `namespace sqnbitgemm_neon`;
// consider renaming it together with the dispatch-table assignment.
// NOTE(review): presumably callers interpret the return value as the number
// of rows processed — verify that a constant 0 cannot stall a caller loop.
size_t SQ2BitGemmKernel_CompInt8_avx2(
    size_t /*BlkLen*/,
    const std::byte* /*QuantA*/,
    const std::byte* /*QuantBData*/,
    const float* /*QuantBScale*/,
    const std::byte* /*QuantBZeroPoint*/,
    float* /*C*/,
    size_t /*CountM*/,
    size_t /*CountN*/,
    size_t /*CountK*/,
    size_t /*BlockCountK*/,
    size_t /*ldc*/,
    const float* /*Bias*/)
{
    // Kernel body not implemented yet.
    return 0;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace sqnbitgemm_neon
|
||||
|
|
@ -197,5 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
|
|||
d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
|
||||
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64
|
||||
|
||||
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
|
||||
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
|
||||
|
||||
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
|
||||
|
||||
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
|
||||
return d;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ SQ2BitGemmKernel_CompInt8_avx2(
|
|||
const float* /*Bias*/
|
||||
)
|
||||
{
|
||||
// reference SQ4BitGemmKernel_CompInt8_avx2
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -101,4 +102,5 @@ QuantizeARow_CompInt8(
|
|||
std::byte* /*QuantA*/
|
||||
)
|
||||
{
|
||||
// shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code.
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1375,5 +1375,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() {
|
|||
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni;
|
||||
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2;
|
||||
|
||||
// change functions if implementations are different from avx2
|
||||
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
|
||||
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
|
||||
|
||||
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
|
||||
|
||||
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
|
||||
|
||||
return d;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ Abstract:
|
|||
//
|
||||
|
||||
#include "sqnbitgemm_kernel_avx_common_fp32.h"
|
||||
#include "sqnbitgemm_bitnet_kernel_avx2.h"
|
||||
|
||||
MLAS_FORCEINLINE void
|
||||
SQ4BitGemmM1Kernel_CompFp32_avx512(
|
||||
|
|
@ -368,5 +369,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
|
|||
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512;
|
||||
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
|
||||
|
||||
// change functions if implementations are different from avx2
|
||||
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
|
||||
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
|
||||
|
||||
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
|
||||
|
||||
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
|
||||
|
||||
return d;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ Abstract:
|
|||
#include "sqnbitgemm_kernel_avx512_int8_blklen32.h"
|
||||
#include "sqnbitgemm_kernel_avx512_int8_blklen64.h"
|
||||
#include "sqnbitgemm_kernel_avx512_int8_blklen128.h"
|
||||
#include "sqnbitgemm_bitnet_kernel_avx2.h"
|
||||
|
||||
MLAS_FORCEINLINE void
|
||||
SQ4BitGemmM1Kernel_CompFp32(
|
||||
|
|
@ -353,5 +354,13 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
|
|||
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni;
|
||||
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
|
||||
|
||||
// change functions if implementations are different from avx2
|
||||
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
|
||||
d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
|
||||
|
||||
d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
|
||||
|
||||
d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
|
||||
return d;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -147,8 +147,6 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
b_zp = 8;
|
||||
} else if constexpr (BlkBitWidth == 2) {
|
||||
b_zp = 2;
|
||||
} else {
|
||||
static_assert(false, "only implemented for 2- and 4-bit quantized B");
|
||||
}
|
||||
|
||||
int pack_size = 8 / BlkBitWidth;
|
||||
|
|
|
|||
Loading…
Reference in a new issue