onnxruntime/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Edward Chen 150c4cb8fe
[MLAS AArch64] SQNBitGemm CompInt8 kernel (#18953)
Implement ARM NEON SQNBitGemm kernel that first block quantizes A to int8 and then does int8 multiplication.
2024-01-12 17:58:08 -08:00

820 lines
23 KiB
C++

/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
sqnbitgemm.cpp
Abstract:
This module implements the float/quantized n-bit integer matrix
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch,
as well as some SQNBitGemm-related query functions.
--*/
#include "sqnbitgemm.h"
#include <cassert>
#ifdef MLAS_JBLAS
#include "jblas_gemm.h"
#endif
namespace
{
enum SQNBitGemmVariant {
SQNBitGemmVariantInvalid = -1,
// Valid variants
SQNBitGemmVariant_BitWidth4_CompFp32 = 0,
SQNBitGemmVariant_BitWidth4_CompInt8,
// End of valid variants
// Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values.
// Its value is used as an array size.
SQNBitGemmVariantCount,
};
SQNBitGemmVariant
GetSQNBitGemmVariant(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
MLAS_UNREFERENCED_PARAMETER(N);
MLAS_UNREFERENCED_PARAMETER(K);
if (BlkBitWidth == 4 &&
(BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) {
if (ComputeType == CompFp32 ||
ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32
return SQNBitGemmVariant_BitWidth4_CompFp32;
} else if (ComputeType == CompInt8 && M == 1) {
return SQNBitGemmVariant_BitWidth4_CompInt8;
}
}
return SQNBitGemmVariantInvalid;
}
} // namespace
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return false;
}
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
Dispatch->QuantizeARow_CompInt8 != nullptr;
}
default: {
return false;
}
}
}
namespace
{
size_t
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
{
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Q8BlkAlignment();
}
default: {
return 1;
}
}
}
size_t
SQNBitGemmPerGemmWorkspaceSize(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
)
{
MLAS_UNREFERENCED_PARAMETER(N);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
// workspace buffer is used for block quantization of A to int8
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
return PerGemmWorkspaceSize;
}
default: {
return 0;
}
}
}
size_t
SQNBitGemmPerGemmWorkspaceStride(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
)
{
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
return MlasDivRoundup(Size, Alignment) * Alignment;
}
} // namespace
size_t MLASCALL
MlasSQNBitGemmBatchWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
if (PerGemmWorkspaceStride == 0) {
return 0;
}
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
return WorkspaceSize + Alignment - 1;
}
namespace
{
void
SQ4BitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkLen,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
)
{
constexpr size_t BlkBitWidth = 4;
assert(BlkLen % 16 == 0);
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t Iterations = N * BlockCountK; // one iteration per block
MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
const size_t n = tid / BlockCountK;
const size_t k_blk = tid % BlockCountK;
const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
const std::byte* QuantBData = QuantBDataBegin + data_offset;
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
//
// Pack 16 4-bit values (8 bytes) at a time like this:
//
// src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
// =>
// dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
//
for (size_t kk = 0; kk < BlkLen; kk += 16) {
for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
const std::byte src0 = QuantBData[byte_pair_idx];
const std::byte src1 = QuantBData[byte_pair_idx + 4];
std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
}
QuantBData += 8;
PackedQuantBData += 8;
}
}
);
}
} // namespace
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
)
{
// Ensure that a general implementation is available on this platform.
// For now, all implementations share the same packed format.
{
// Currently, there are implementations specific to M = 1, so pick a more general M > 1.
constexpr size_t M = 2;
// A CompUndef implementation should be available if any is available.
constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
const bool HasGeneralImplementation =
MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (!HasGeneralImplementation) {
return 0;
}
}
if (BlkBitWidth == 4) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
return PackedQuantBDataSize;
}
return 0;
}
void MLASCALL
MlasSQNBitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool
)
{
if (BlkBitWidth == 4) {
SQ4BitGemmPackQuantBData(
N,
K,
BlkLen,
static_cast<const std::byte*>(QuantBData),
static_cast<std::byte*>(PackedQuantBData),
ThreadPool
);
}
}
namespace
{
MLAS_FORCEINLINE void
AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
{
for (size_t m = 0; m < CountM; m++) {
const float* bias = Bias;
float* sum = C;
for (size_t n = 0; n < CountN; n += 4) {
if (CountN - n < 4) {
for (size_t nn = n; nn < CountN; nn++) {
*sum += *bias;
sum++;
bias++;
}
break;
}
MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
MlasStoreFloat32x4(sum, acc_x);
bias += 4;
sum += 4;
}
C += ldc;
}
}
typedef void(SQNBitGemmFn)(
size_t BlkLen,
size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* PerGemmWorkspace,
size_t RangeStartM,
size_t RangeCountM,
size_t RangeStartN,
size_t RangeCountN
);
void
SQ4BitGemm_CompFp32(
const size_t BlkLen,
const size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
void* const PerGemmWorkspace,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
{
constexpr size_t BlkBitWidth = 4;
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);
const size_t lda = DataParams->lda;
const size_t ldc = DataParams->ldc;
const size_t k_blks = MlasDivRoundup(K, BlkLen);
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
const float* A = DataParams->A + RangeStartM * lda;
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
const float* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}
constexpr size_t StrideN = 32;
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
MlasThreadedBufAlloc(bufsize);
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
//
// Step through each slice of matrix B along the N dimension.
//
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, StrideN);
//
// Step through each slice of matrix A along the M dimension.
//
const float* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
);
#else
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
#endif
if (bias) {
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
}
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
RowsHandled, CountN, ldc
);
}
c_blk += ldc * RowsHandled;
a_row += lda * RowsHandled;
RowsRemaining -= RowsHandled;
}
}
}
void
SQ4BitGemm_CompInt8(
const size_t BlkLen,
const size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
void* const PerGemmWorkspace,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
{
constexpr size_t BlkBitWidth = 4;
const size_t k_blks = MlasDivRoundup(K, BlkLen);
const size_t lda = k_blks * Q8BlkSize(BlkLen);
const size_t ldc = DataParams->ldc;
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
const std::byte* QuantA = static_cast<const std::byte*>(PerGemmWorkspace) + RangeStartM * lda;
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
const std::byte* a_row = QuantA;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}
assert(false && "not implemented for M > 1");
}
typedef void(InitializeWorkspaceFn)(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
);
void
InitializeWorkspace_CompInt8(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
)
{
MLAS_UNREFERENCED_PARAMETER(N);
const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8;
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen);
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];
const float* ARowPtr = data.A;
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
for (size_t m = 0; m < M; ++m) {
QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
ARowPtr += data.lda;
QuantARowPtr += QuantAStride;
}
});
}
struct Operations {
InitializeWorkspaceFn* InitializeWorkspace = nullptr;
SQNBitGemmFn* SQNBitGemm = nullptr;
};
constexpr auto OperationMap = []() {
std::array<Operations, SQNBitGemmVariantCount> ops;
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
return ops;
}();
} // namespace
void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool
)
{
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
assert(Variant != SQNBitGemmVariantInvalid);
//
// Ensure `Workspace` has correct alignment.
//
if (Workspace != nullptr) {
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
Workspace = reinterpret_cast<void*>(
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
);
}
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
InitializeWorkspaceOperation != nullptr) {
InitializeWorkspaceOperation(
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
);
}
const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
const auto* Data = &DataParams[gemm_i];
void* PerGemmWorkspace =
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
}
return;
}
//
// Compute the number of target threads given the complexity of the SGEMM
// operation. Small requests should run using the single threaded path.
//
const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
}
ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
if (ThreadsPerGemm < 1) {
ThreadsPerGemm = 1;
}
constexpr size_t StrideM = 128;
size_t nc = N;
if (ThreadsPerGemm > 1) {
// more than one thread per GEMM
const size_t BlockedM = MlasDivRoundup(M, StrideM);
const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
if (max_nc < nc) {
nc = std::min(
nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
MLAS_QGEMM_STRIDEN_THREAD_ALIGN
);
}
}
const size_t StrideN = nc;
const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
ThreadsPerGemm = ThreadCountM * ThreadCountN;
MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
const auto* Data = &DataParams[gemm_i];
void* PerGemmWorkspace = reinterpret_cast<void*>(
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride
);
const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
const size_t RangeStartM = ThreadIdM * StrideM;
const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
const size_t RangeStartN = ThreadIdN * StrideN;
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
});
}
size_t MLASCALL
MlasNBitsGemmPackBSize(
size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
)
{
#ifdef MLAS_JBLAS
if (nbits == 4) {
auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
if (jsize) {
return jsize;
}
}
#endif
(void)(N);
(void)(K);
(void)(BlkSize);
(void)(nbits);
(void)(isAsym);
(void)(CompType);
return 0;
}
void MLASCALL
MlasNBitsGemmPackB(
void* PackedBuf,
const uint8_t* QData,
const float* Scale,
const uint8_t* Zp,
size_t N,
size_t K,
size_t ldb,
size_t BlkSize,
int nbits,
bool isAsym,
bool lastCall,
MLAS_SQNBIT_COMPUTE_TYPE CompType,
MLAS_THREADPOOL* ThreadPool
)
{
#ifdef MLAS_JBLAS
if (nbits == 4) {
if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) {
return;
}
}
#endif
(void)(PackedBuf);
(void)(QData);
(void)(Scale);
(void)(Zp);
(void)(N);
(void)(K);
(void)(ldb);
(void)(BlkSize);
(void)(nbits);
(void)(isAsym);
(void)(lastCall);
(void)(CompType);
(void)(ThreadPool);
}
void MLASCALL
MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
{
#ifdef MLAS_JBLAS
if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
return;
}
#endif
(void)(FpData);
(void)(PackedBuf);
(void)(N);
(void)(K);
(void)(ldb);
(void)(ThreadPool);
}
size_t MLASCALL
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
)
{
#ifdef MLAS_JBLAS
return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
#endif
(void)(M);
(void)(N);
(void)(K);
(void)(BatchN);
(void)(DataParams);
return 0;
}
void MLASCALL
MlasSQNBitsGemmBatchPackedB(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
void* WorkSpace,
MLAS_THREADPOOL* ThreadPool
)
{
GetMlasPlatform();
#ifdef MLAS_JBLAS
if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) {
// PackedWeight is created by jblas
return;
}
#endif
(void)(M);
(void)(N);
(void)(K);
(void)(BatchN);
(void)(DataParams);
(void)(WorkSpace);
(void)(ThreadPool);
}