mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Implement ARM NEON SQNBitGemm kernel that first block quantizes A to int8 and then does int8 multiplication.
820 lines
23 KiB
C++
820 lines
23 KiB
C++
/*++
|
|
|
|
Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
Licensed under the MIT License.
|
|
|
|
Module Name:
|
|
|
|
sqnbitgemm.cpp
|
|
|
|
Abstract:
|
|
|
|
This module implements the float/quantized n-bit integer matrix
|
|
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch,
|
|
as well as some SQNBitGemm-related query functions.
|
|
--*/
|
|
|
|
#include "sqnbitgemm.h"
|
|
|
|
#include <cassert>
|
|
|
|
#ifdef MLAS_JBLAS
|
|
#include "jblas_gemm.h"
|
|
#endif
|
|
|
|
namespace
|
|
{
|
|
|
|
enum SQNBitGemmVariant {
|
|
SQNBitGemmVariantInvalid = -1,
|
|
|
|
// Valid variants
|
|
|
|
SQNBitGemmVariant_BitWidth4_CompFp32 = 0,
|
|
SQNBitGemmVariant_BitWidth4_CompInt8,
|
|
|
|
// End of valid variants
|
|
|
|
// Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values.
|
|
// Its value is used as an array size.
|
|
SQNBitGemmVariantCount,
|
|
};
|
|
|
|
SQNBitGemmVariant
|
|
GetSQNBitGemmVariant(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkBitWidth,
|
|
size_t BlkLen,
|
|
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(N);
|
|
MLAS_UNREFERENCED_PARAMETER(K);
|
|
|
|
if (BlkBitWidth == 4 &&
|
|
(BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) {
|
|
if (ComputeType == CompFp32 ||
|
|
ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32
|
|
return SQNBitGemmVariant_BitWidth4_CompFp32;
|
|
} else if (ComputeType == CompInt8 && M == 1) {
|
|
return SQNBitGemmVariant_BitWidth4_CompInt8;
|
|
}
|
|
}
|
|
|
|
return SQNBitGemmVariantInvalid;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
bool MLASCALL
|
|
MlasIsSQNBitGemmAvailable(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkBitWidth,
|
|
size_t BlkLen,
|
|
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
|
)
|
|
{
|
|
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
|
|
if (Dispatch == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
|
|
|
switch (Variant) {
|
|
case SQNBitGemmVariant_BitWidth4_CompFp32: {
|
|
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
|
|
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
|
|
}
|
|
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
|
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
|
|
Dispatch->QuantizeARow_CompInt8 != nullptr;
|
|
}
|
|
default: {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace
|
|
{
|
|
|
|
size_t
|
|
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
|
|
{
|
|
switch (Variant) {
|
|
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
|
return Q8BlkAlignment();
|
|
}
|
|
default: {
|
|
return 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
size_t
|
|
SQNBitGemmPerGemmWorkspaceSize(
|
|
SQNBitGemmVariant Variant,
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkLen
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(N);
|
|
|
|
switch (Variant) {
|
|
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
|
// workspace buffer is used for block quantization of A to int8
|
|
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
|
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
|
|
return PerGemmWorkspaceSize;
|
|
}
|
|
default: {
|
|
return 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
size_t
|
|
SQNBitGemmPerGemmWorkspaceStride(
|
|
SQNBitGemmVariant Variant,
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkLen
|
|
)
|
|
{
|
|
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
|
|
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
|
return MlasDivRoundup(Size, Alignment) * Alignment;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
size_t MLASCALL
|
|
MlasSQNBitGemmBatchWorkspaceSize(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BatchN,
|
|
size_t BlkBitWidth,
|
|
size_t BlkLen,
|
|
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
|
)
|
|
{
|
|
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
|
|
|
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
|
if (PerGemmWorkspaceStride == 0) {
|
|
return 0;
|
|
}
|
|
|
|
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
|
|
|
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
|
|
|
|
return WorkspaceSize + Alignment - 1;
|
|
}
|
|
|
|
namespace
|
|
{
|
|
|
|
void
|
|
SQ4BitGemmPackQuantBData(
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkLen,
|
|
const std::byte* QuantBDataBegin,
|
|
std::byte* PackedQuantBDataBegin,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
constexpr size_t BlkBitWidth = 4;
|
|
|
|
assert(BlkLen % 16 == 0);
|
|
|
|
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
|
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
|
const size_t Iterations = N * BlockCountK; // one iteration per block
|
|
|
|
MlasTrySimpleParallel(
|
|
ThreadPool, Iterations,
|
|
[&](ptrdiff_t tid) {
|
|
const size_t n = tid / BlockCountK;
|
|
const size_t k_blk = tid % BlockCountK;
|
|
|
|
const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
|
|
const std::byte* QuantBData = QuantBDataBegin + data_offset;
|
|
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
|
|
|
|
//
|
|
// Pack 16 4-bit values (8 bytes) at a time like this:
|
|
//
|
|
// src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
|
|
// =>
|
|
// dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
|
|
//
|
|
for (size_t kk = 0; kk < BlkLen; kk += 16) {
|
|
for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
|
|
const std::byte src0 = QuantBData[byte_pair_idx];
|
|
const std::byte src1 = QuantBData[byte_pair_idx + 4];
|
|
|
|
std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
|
|
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
|
|
|
|
dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
|
|
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
|
|
}
|
|
|
|
QuantBData += 8;
|
|
PackedQuantBData += 8;
|
|
}
|
|
}
|
|
);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
size_t MLASCALL
|
|
MlasSQNBitGemmPackQuantBDataSize(
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkBitWidth,
|
|
size_t BlkLen
|
|
)
|
|
{
|
|
// Ensure that a general implementation is available on this platform.
|
|
// For now, all implementations share the same packed format.
|
|
{
|
|
// Currently, there are implementations specific to M = 1, so pick a more general M > 1.
|
|
constexpr size_t M = 2;
|
|
// A CompUndef implementation should be available if any is available.
|
|
constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
|
|
const bool HasGeneralImplementation =
|
|
MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
|
if (!HasGeneralImplementation) {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
if (BlkBitWidth == 4) {
|
|
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
|
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
|
return PackedQuantBDataSize;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
void MLASCALL
|
|
MlasSQNBitGemmPackQuantBData(
|
|
size_t N,
|
|
size_t K,
|
|
size_t BlkBitWidth,
|
|
size_t BlkLen,
|
|
const void* QuantBData,
|
|
void* PackedQuantBData,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
if (BlkBitWidth == 4) {
|
|
SQ4BitGemmPackQuantBData(
|
|
N,
|
|
K,
|
|
BlkLen,
|
|
static_cast<const std::byte*>(QuantBData),
|
|
static_cast<std::byte*>(PackedQuantBData),
|
|
ThreadPool
|
|
);
|
|
}
|
|
}
|
|
|
|
namespace
|
|
{
|
|
|
|
MLAS_FORCEINLINE void
|
|
AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
|
|
{
|
|
for (size_t m = 0; m < CountM; m++) {
|
|
const float* bias = Bias;
|
|
float* sum = C;
|
|
for (size_t n = 0; n < CountN; n += 4) {
|
|
if (CountN - n < 4) {
|
|
for (size_t nn = n; nn < CountN; nn++) {
|
|
*sum += *bias;
|
|
sum++;
|
|
bias++;
|
|
}
|
|
break;
|
|
}
|
|
|
|
MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
|
|
acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
|
|
MlasStoreFloat32x4(sum, acc_x);
|
|
bias += 4;
|
|
sum += 4;
|
|
}
|
|
C += ldc;
|
|
}
|
|
}
|
|
|
|
typedef void(SQNBitGemmFn)(
|
|
size_t BlkLen,
|
|
size_t K,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
|
void* PerGemmWorkspace,
|
|
size_t RangeStartM,
|
|
size_t RangeCountM,
|
|
size_t RangeStartN,
|
|
size_t RangeCountN
|
|
);
|
|
|
|
void
|
|
SQ4BitGemm_CompFp32(
|
|
const size_t BlkLen,
|
|
const size_t K,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
|
|
void* const PerGemmWorkspace,
|
|
const size_t RangeStartM,
|
|
const size_t RangeCountM,
|
|
const size_t RangeStartN,
|
|
const size_t RangeCountN
|
|
)
|
|
{
|
|
constexpr size_t BlkBitWidth = 4;
|
|
|
|
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);
|
|
|
|
const size_t lda = DataParams->lda;
|
|
const size_t ldc = DataParams->ldc;
|
|
|
|
const size_t k_blks = MlasDivRoundup(K, BlkLen);
|
|
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
|
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
|
|
|
|
const float* A = DataParams->A + RangeStartM * lda;
|
|
|
|
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
|
|
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
|
|
const std::byte* QuantBZeroPoint =
|
|
(DataParams->QuantBZeroPoint == nullptr)
|
|
? nullptr
|
|
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
|
|
|
|
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
|
|
|
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
|
|
|
if (RangeCountM == 1) {
|
|
size_t CountN;
|
|
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
|
CountN = std::min(RangeCountN - n, size_t{128});
|
|
|
|
const float* a_row = A;
|
|
const std::byte* b_col = QuantBData + n * ldb;
|
|
const float* b_col_scale = QuantBScale + n * k_blks;
|
|
const std::byte* b_col_zp =
|
|
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
|
float* c_blk = C + n;
|
|
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
|
|
|
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
|
|
BlkLen,
|
|
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
|
);
|
|
|
|
if (DataParams->PostProcessor != nullptr) {
|
|
DataParams->PostProcessor->Process(
|
|
DataParams->C, RangeStartM, RangeStartN + n,
|
|
RangeCountM, CountN, ldc
|
|
);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
constexpr size_t StrideN = 32;
|
|
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
|
|
MlasThreadedBufAlloc(bufsize);
|
|
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
|
|
|
|
//
|
|
// Step through each slice of matrix B along the N dimension.
|
|
//
|
|
size_t CountN;
|
|
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
|
CountN = std::min(RangeCountN - n, StrideN);
|
|
|
|
//
|
|
// Step through each slice of matrix A along the M dimension.
|
|
//
|
|
const float* a_row = A;
|
|
const std::byte* b_col = QuantBData + n * ldb;
|
|
const float* b_col_scale = QuantBScale + n * k_blks;
|
|
const std::byte* b_col_zp =
|
|
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
|
float* c_blk = C + n;
|
|
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
|
|
|
GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
|
|
BlkLen,
|
|
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
|
|
);
|
|
|
|
size_t RowsRemaining = RangeCountM;
|
|
while (RowsRemaining > 0) {
|
|
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
|
|
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
|
|
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
|
|
);
|
|
#else
|
|
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
|
|
#endif
|
|
|
|
if (bias) {
|
|
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
|
|
}
|
|
if (DataParams->PostProcessor != nullptr) {
|
|
DataParams->PostProcessor->Process(
|
|
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
|
|
RowsHandled, CountN, ldc
|
|
);
|
|
}
|
|
|
|
c_blk += ldc * RowsHandled;
|
|
a_row += lda * RowsHandled;
|
|
RowsRemaining -= RowsHandled;
|
|
}
|
|
}
|
|
}
|
|
|
|
void
|
|
SQ4BitGemm_CompInt8(
|
|
const size_t BlkLen,
|
|
const size_t K,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
|
|
void* const PerGemmWorkspace,
|
|
const size_t RangeStartM,
|
|
const size_t RangeCountM,
|
|
const size_t RangeStartN,
|
|
const size_t RangeCountN
|
|
)
|
|
{
|
|
constexpr size_t BlkBitWidth = 4;
|
|
|
|
const size_t k_blks = MlasDivRoundup(K, BlkLen);
|
|
|
|
const size_t lda = k_blks * Q8BlkSize(BlkLen);
|
|
const size_t ldc = DataParams->ldc;
|
|
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
|
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
|
|
|
|
const std::byte* QuantA = static_cast<const std::byte*>(PerGemmWorkspace) + RangeStartM * lda;
|
|
|
|
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
|
|
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
|
|
const std::byte* QuantBZeroPoint =
|
|
(DataParams->QuantBZeroPoint == nullptr)
|
|
? nullptr
|
|
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
|
|
|
|
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
|
|
|
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
|
|
|
if (RangeCountM == 1) {
|
|
size_t CountN;
|
|
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
|
CountN = std::min(RangeCountN - n, size_t{128});
|
|
|
|
const std::byte* a_row = QuantA;
|
|
const std::byte* b_col = QuantBData + n * ldb;
|
|
const float* b_col_scale = QuantBScale + n * k_blks;
|
|
const std::byte* b_col_zp =
|
|
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
|
float* c_blk = C + n;
|
|
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
|
|
|
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
|
|
BlkLen,
|
|
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
|
);
|
|
|
|
if (DataParams->PostProcessor != nullptr) {
|
|
DataParams->PostProcessor->Process(
|
|
DataParams->C, RangeStartM, RangeStartN + n,
|
|
RangeCountM, CountN, ldc
|
|
);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
assert(false && "not implemented for M > 1");
|
|
}
|
|
|
|
typedef void(InitializeWorkspaceFn)(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BatchN,
|
|
size_t BlkLen,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
|
void* Workspace,
|
|
size_t PerGemmWorkspaceStride,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
);
|
|
|
|
void
|
|
InitializeWorkspace_CompInt8(
|
|
size_t M,
|
|
size_t N,
|
|
size_t K,
|
|
size_t BatchN,
|
|
size_t BlkLen,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
|
void* Workspace,
|
|
size_t PerGemmWorkspaceStride,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
MLAS_UNREFERENCED_PARAMETER(N);
|
|
|
|
const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8;
|
|
|
|
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
|
const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen);
|
|
|
|
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
|
|
const auto& data = DataParams[gemm_idx];
|
|
|
|
const float* ARowPtr = data.A;
|
|
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
|
|
|
|
for (size_t m = 0; m < M; ++m) {
|
|
QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
|
|
|
|
ARowPtr += data.lda;
|
|
QuantARowPtr += QuantAStride;
|
|
}
|
|
});
|
|
}
|
|
|
|
struct Operations {
|
|
InitializeWorkspaceFn* InitializeWorkspace = nullptr;
|
|
SQNBitGemmFn* SQNBitGemm = nullptr;
|
|
};
|
|
|
|
constexpr auto OperationMap = []() {
|
|
std::array<Operations, SQNBitGemmVariantCount> ops;
|
|
|
|
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
|
|
|
|
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
|
|
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
|
|
|
|
return ops;
|
|
}();
|
|
|
|
} // namespace
|
|
|
|
void MLASCALL
|
|
MlasSQNBitGemmBatch(
|
|
const size_t M,
|
|
const size_t N,
|
|
const size_t K,
|
|
const size_t BatchN,
|
|
const size_t BlkBitWidth,
|
|
const size_t BlkLen,
|
|
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
|
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
|
void* Workspace,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
|
assert(Variant != SQNBitGemmVariantInvalid);
|
|
|
|
//
|
|
// Ensure `Workspace` has correct alignment.
|
|
//
|
|
if (Workspace != nullptr) {
|
|
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
|
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
|
|
Workspace = reinterpret_cast<void*>(
|
|
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
|
|
);
|
|
}
|
|
|
|
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
|
|
|
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
|
|
InitializeWorkspaceOperation != nullptr) {
|
|
InitializeWorkspaceOperation(
|
|
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
|
|
);
|
|
}
|
|
|
|
const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
|
|
|
|
if (ThreadPool == nullptr) {
|
|
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
|
|
const auto* Data = &DataParams[gemm_i];
|
|
void* PerGemmWorkspace =
|
|
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
|
|
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
|
|
}
|
|
return;
|
|
}
|
|
|
|
//
|
|
// Compute the number of target threads given the complexity of the SGEMM
|
|
// operation. Small requests should run using the single threaded path.
|
|
//
|
|
|
|
const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
|
|
|
|
ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
|
|
|
|
ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
|
|
|
|
if (TargetThreadCount >= MaximumThreadCount) {
|
|
TargetThreadCount = MaximumThreadCount;
|
|
}
|
|
|
|
ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
|
|
if (ThreadsPerGemm < 1) {
|
|
ThreadsPerGemm = 1;
|
|
}
|
|
|
|
constexpr size_t StrideM = 128;
|
|
|
|
size_t nc = N;
|
|
if (ThreadsPerGemm > 1) {
|
|
// more than one thread per GEMM
|
|
|
|
const size_t BlockedM = MlasDivRoundup(M, StrideM);
|
|
const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
|
|
if (max_nc < nc) {
|
|
nc = std::min(
|
|
nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
|
|
MLAS_QGEMM_STRIDEN_THREAD_ALIGN
|
|
);
|
|
}
|
|
}
|
|
const size_t StrideN = nc;
|
|
|
|
const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
|
|
const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
|
|
ThreadsPerGemm = ThreadCountM * ThreadCountN;
|
|
|
|
MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
|
|
const auto gemm_i = tid / ThreadsPerGemm;
|
|
const auto blk_i = tid % ThreadsPerGemm;
|
|
const auto* Data = &DataParams[gemm_i];
|
|
void* PerGemmWorkspace = reinterpret_cast<void*>(
|
|
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride
|
|
);
|
|
|
|
const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
|
|
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
|
|
|
|
const size_t RangeStartM = ThreadIdM * StrideM;
|
|
const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
|
|
|
|
const size_t RangeStartN = ThreadIdN * StrideN;
|
|
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
|
|
|
|
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
|
|
});
|
|
}
|
|
|
|
size_t MLASCALL
|
|
MlasNBitsGemmPackBSize(
|
|
size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
|
|
)
|
|
{
|
|
#ifdef MLAS_JBLAS
|
|
if (nbits == 4) {
|
|
auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
|
|
if (jsize) {
|
|
return jsize;
|
|
}
|
|
}
|
|
#endif
|
|
(void)(N);
|
|
(void)(K);
|
|
(void)(BlkSize);
|
|
(void)(nbits);
|
|
(void)(isAsym);
|
|
(void)(CompType);
|
|
return 0;
|
|
}
|
|
|
|
void MLASCALL
|
|
MlasNBitsGemmPackB(
|
|
void* PackedBuf,
|
|
const uint8_t* QData,
|
|
const float* Scale,
|
|
const uint8_t* Zp,
|
|
size_t N,
|
|
size_t K,
|
|
size_t ldb,
|
|
size_t BlkSize,
|
|
int nbits,
|
|
bool isAsym,
|
|
bool lastCall,
|
|
MLAS_SQNBIT_COMPUTE_TYPE CompType,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
#ifdef MLAS_JBLAS
|
|
if (nbits == 4) {
|
|
if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) {
|
|
return;
|
|
}
|
|
}
|
|
#endif
|
|
(void)(PackedBuf);
|
|
(void)(QData);
|
|
(void)(Scale);
|
|
(void)(Zp);
|
|
(void)(N);
|
|
(void)(K);
|
|
(void)(ldb);
|
|
(void)(BlkSize);
|
|
(void)(nbits);
|
|
(void)(isAsym);
|
|
(void)(lastCall);
|
|
(void)(CompType);
|
|
(void)(ThreadPool);
|
|
}
|
|
|
|
void MLASCALL
|
|
MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
|
|
{
|
|
#ifdef MLAS_JBLAS
|
|
if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
|
|
return;
|
|
}
|
|
#endif
|
|
(void)(FpData);
|
|
(void)(PackedBuf);
|
|
(void)(N);
|
|
(void)(K);
|
|
(void)(ldb);
|
|
(void)(ThreadPool);
|
|
}
|
|
|
|
size_t MLASCALL
|
|
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
|
|
const size_t M,
|
|
const size_t N,
|
|
const size_t K,
|
|
const size_t BatchN,
|
|
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
|
|
)
|
|
{
|
|
#ifdef MLAS_JBLAS
|
|
return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
|
|
#endif
|
|
(void)(M);
|
|
(void)(N);
|
|
(void)(K);
|
|
(void)(BatchN);
|
|
(void)(DataParams);
|
|
return 0;
|
|
}
|
|
|
|
void MLASCALL
|
|
MlasSQNBitsGemmBatchPackedB(
|
|
const size_t M,
|
|
const size_t N,
|
|
const size_t K,
|
|
const size_t BatchN,
|
|
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
|
|
void* WorkSpace,
|
|
MLAS_THREADPOOL* ThreadPool
|
|
)
|
|
{
|
|
GetMlasPlatform();
|
|
#ifdef MLAS_JBLAS
|
|
if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) {
|
|
// PackedWeight is created by jblas
|
|
return;
|
|
}
|
|
#endif
|
|
(void)(M);
|
|
(void)(N);
|
|
(void)(K);
|
|
(void)(BatchN);
|
|
(void)(DataParams);
|
|
(void)(WorkSpace);
|
|
(void)(ThreadPool);
|
|
}
|