[MLAS AArch64] SQNBitGemm optimization (#19272)

1. Add support for packing 4-bit values 32 at a time for CompInt8. 32 4-bit values can fit into a single 128-bit NEON register. For CompInt8, this enables a more efficient path for block sizes greater than or equal to 32. CompFp32 seems to do better with handling 16 elements at a time, so this 32-value packing is not used there.
Pack differently based on compute type. Adjust APIs to handle this.

2. Introduce template argument for whether to handle zero-point. This results in less code for the no zero-point (symmetric) case. However, there is a binary size increase due to the additional template instantiations.
This commit is contained in:
Edward Chen 2024-01-30 14:29:12 -08:00 committed by GitHub
parent 04afe77305
commit c379a89bcb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 561 additions and 235 deletions

View file

@ -9,6 +9,7 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
#endif
@ -16,6 +17,39 @@
namespace onnxruntime {
namespace contrib {
namespace {
int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) {
const auto accuracy_level = std::clamp(accuracy_level_attr,
static_cast<int64_t>(CompMostAccurate),
static_cast<int64_t>(CompLeastAccurate));
#if defined(ORT_NEURAL_SPEED)
ORT_UNUSED_PARAMETER(nbits);
ORT_UNUSED_PARAMETER(block_size);
// Neural Speed APIs already expect a minimum accuracy level so just use the given value.
return accuracy_level;
#else // defined(ORT_NEURAL_SPEED)
// Find a supported accuracy level that is not less accurate than the one given.
// CompMostAccurate is always supported with the fallback implementation.
// Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed.
int64_t effective_accuracy_level = accuracy_level;
for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(effective_accuracy_level);
if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) {
break;
}
}
return effective_accuracy_level;
#endif // defined(ORT_NEURAL_SPEED)
}
} // namespace
class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
@ -24,7 +58,7 @@ class MatMulNBits final : public OpKernel {
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} {
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
@ -58,17 +92,22 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};
#ifdef ORT_NEURAL_SPEED
#if defined(ORT_NEURAL_SPEED)
bool is_asym_{false};
bool all_constant_{false};
#endif
#endif // defined(ORT_NEURAL_SPEED)
};
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
#ifdef ORT_NEURAL_SPEED
#if defined(ORT_NEURAL_SPEED)
if (!all_constant_) {
return Status::OK();
}
@ -116,11 +155,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
#else // defined(ORT_NEURAL_SPEED)
if (input_idx == 1) {
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
if (packed_b_size_ == 0) return Status::OK();
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@ -136,7 +181,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
#ifdef ORT_NEURAL_SPEED
#if defined(ORT_NEURAL_SPEED)
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
@ -159,6 +206,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
}
#endif // defined(ORT_NEURAL_SPEED)
return Status::OK();
}
@ -167,8 +215,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();
#ifdef ORT_NEURAL_SPEED
if (packed_b_.get()) {
#if defined(ORT_NEURAL_SPEED)
if (packed_b_) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
MatMulComputeHelper helper;
@ -234,37 +284,43 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
[](size_t offset) { return offset == 0; });
if (has_single_b_matrix && packed_b_) {
for (int64_t accuracy_level = accuracy_level_;
accuracy_level >= static_cast<int64_t>(CompMostAccurate);
--accuracy_level) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
}
if (has_single_b_matrix) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);
return Status::OK();
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
}
const void* b_data = [&]() -> const void* {
if (packed_b_) {
return packed_b_.get();
}
const Tensor* b = ctx->Input<Tensor>(1);
return b->DataRaw();
}();
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = b_data;
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);
return Status::OK();
}
}

View file

@ -37,9 +37,7 @@ typedef enum {
CompMostAccurate = CompUndef,
CompLeastAccurate = CompInt8,
} MLAS_SQNBIT_COMPUTE_TYPE;
using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these
} MLAS_SQNBIT_GEMM_COMPUTE_TYPE;
/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
@ -102,18 +100,12 @@ MlasSQNBitGemmBatch(
/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize(
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
/**
@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize(
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
* @param[in] ThreadPool optional thread pool to use
@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData(
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr

View file

@ -39,23 +39,17 @@ enum SQNBitGemmVariant {
SQNBitGemmVariant
GetSQNBitGemmVariant(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
MLAS_UNREFERENCED_PARAMETER(N);
MLAS_UNREFERENCED_PARAMETER(K);
if (BlkBitWidth == 4 &&
(BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) {
if (ComputeType == CompFp32 ||
ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32
return SQNBitGemmVariant_BitWidth4_CompFp32;
} else if (ComputeType == CompInt8 && M == 1) {
} else if (ComputeType == CompInt8) {
return SQNBitGemmVariant_BitWidth4_CompInt8;
}
}
@ -67,9 +61,6 @@ GetSQNBitGemmVariant(
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
@ -80,7 +71,7 @@ MlasIsSQNBitGemmAvailable(
return false;
}
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
@ -164,7 +155,7 @@ MlasSQNBitGemmBatchWorkspaceSize(
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
if (PerGemmWorkspaceStride == 0) {
@ -178,91 +169,24 @@ MlasSQNBitGemmBatchWorkspaceSize(
return WorkspaceSize + Alignment - 1;
}
namespace
{
void
SQ4BitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkLen,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
)
{
constexpr size_t BlkBitWidth = 4;
assert(BlkLen % 16 == 0);
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t Iterations = N * BlockCountK; // one iteration per block
MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
const size_t n = tid / BlockCountK;
const size_t k_blk = tid % BlockCountK;
const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
const std::byte* QuantBData = QuantBDataBegin + data_offset;
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
//
// Pack 16 4-bit values (8 bytes) at a time like this:
//
// src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
// =>
// dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
//
for (size_t kk = 0; kk < BlkLen; kk += 16) {
for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
const std::byte src0 = QuantBData[byte_pair_idx];
const std::byte src1 = QuantBData[byte_pair_idx + 4];
std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
}
QuantBData += 8;
PackedQuantBData += 8;
}
}
);
}
} // namespace
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
// Ensure that a general implementation is available on this platform.
// For now, all implementations share the same packed format.
{
// Currently, there are implementations specific to M = 1, so pick a more general M > 1.
constexpr size_t M = 2;
// A CompUndef implementation should be available if any is available.
constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
const bool HasGeneralImplementation =
MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (!HasGeneralImplementation) {
return 0;
}
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return 0;
}
if (BlkBitWidth == 4) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
return PackedQuantBDataSize;
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) {
return Dispatch->SQ4BitGemmPackQuantBDataSize(
N, K, BlkLen, ComputeType
);
}
return 0;
@ -274,20 +198,28 @@ MlasSQNBitGemmPackQuantBData(
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool
)
{
if (BlkBitWidth == 4) {
SQ4BitGemmPackQuantBData(
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return;
}
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) {
Dispatch->SQ4BitGemmPackQuantBData(
N,
K,
BlkLen,
ComputeType,
static_cast<const std::byte*>(QuantBData),
static_cast<std::byte*>(PackedQuantBData),
ThreadPool
);
return;
}
}
@ -512,7 +444,37 @@ SQ4BitGemm_CompInt8(
return;
}
assert(false && "not implemented for M > 1");
// This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
// TODO Replace it with an optimized implementation.
//
// The N range is walked in column blocks of at most 128 columns. For each column block,
// every row of the M range is computed with the M=1 kernel, and only then is the
// post-processor applied to the finished (RangeCountM x CountN) output tile.
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
    CountN = std::min(RangeCountN - n, size_t{128});

    // Per-column-block starting positions; the row pointers are advanced in the loop below.
    const std::byte* a_row = QuantA;
    const std::byte* b_col = QuantBData + n * ldb;
    const float* b_col_scale = QuantBScale + n * k_blks;
    const std::byte* b_col_zp =
        (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
    float* c_blk = C + n;
    const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

    for (size_t m = 0; m < RangeCountM; ++m) {
        GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
            BlkLen,
            a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
        );

        c_blk += ldc;
        a_row += lda;
    }

    // Post-process the tile exactly once, after all of its rows have been written.
    // The call covers RangeCountM rows, so issuing it from inside the row loop would
    // touch rows that are not computed yet and re-process earlier rows on every iteration.
    if (DataParams->PostProcessor != nullptr) {
        DataParams->PostProcessor->Process(
            DataParams->C, RangeStartM, RangeStartN + n,
            RangeCountM, CountN, ldc
        );
    }
}
}
typedef void(InitializeWorkspaceFn)(
@ -594,7 +556,7 @@ MlasSQNBitGemmBatch(
MLAS_THREADPOOL* ThreadPool
)
{
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
assert(Variant != SQNBitGemmVariantInvalid);
//

View file

@ -99,6 +99,33 @@ Q8BlkAlignment()
//
struct MLAS_SQNBIT_GEMM_DISPATCH {
//
// Quantized B data packing function prototypes.
//
/** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */
typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr;
/** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */
typedef void(SQ4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
);
SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
//
// CompFp32 kernel function prototypes.
//

View file

@ -15,14 +15,115 @@ Abstract:
--*/
#include "sqnbitgemm.h"
#include <arm_neon.h>
#include <algorithm>
#include <cassert>
#include <utility>
#include "sqnbitgemm.h"
//
// Quantized B data packing function implementation.
//
namespace
{
/**
 * Returns the size in bytes of the packed 4-bit quantized B data for an N x K matrix:
 * N columns, each holding ceil(K / BlkLen) blocks of
 * MlasQNBitBlkDataSizeInBytes(4, BlkLen) bytes.
 *
 * ComputeType is accepted for interface compatibility but does not affect the size.
 */
size_t
SQ4BitGemmPackQuantBDataSize(
    size_t N,
    size_t K,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    // The packed layout only permutes nibbles within a block, so every compute type
    // needs the same number of bytes.
    MLAS_UNREFERENCED_PARAMETER(ComputeType);

    constexpr size_t BlkBitWidth = 4;

    const size_t BlksPerColumn = MlasDivRoundup(K, BlkLen);
    const size_t BytesPerBlk = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);

    return N * BlksPerColumn * BytesPerBlk;
}
/**
 * Packs 4-bit quantized B data block by block, one parallel work item per block.
 *
 * Within a block, values are rearranged one sub-block at a time. The sub-block length is
 * 32 for CompInt8 when BlkLen != 16, and 16 otherwise.
 *
 * For a 16-value sub-block (8 bytes):
 *
 * src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
 *   =>
 * dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
 *
 * For a 32-value sub-block (16 bytes):
 *
 * src: | v0  v1  | v2  v3  | ... | v28 v29 | v30 v31 |
 *   =>
 * dst: | v0  v16 | v1  v17 | ... | v14 v30 | v15 v31 |
 *
 * The source and packed buffers use the same per-block byte offsets.
 */
void
SQ4BitGemmPackQuantBData(
    size_t N,
    size_t K,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
    const std::byte* QuantBDataBegin,
    std::byte* PackedQuantBDataBegin,
    MLAS_THREADPOOL* ThreadPool
)
{
    constexpr size_t BlkBitWidth = 4;

    assert(BlkLen >= 16 && BlkLen % 16 == 0);

    const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
    const size_t BytesPerBlk = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
    const size_t TotalBlkCount = N * BlockCountK;  // one work item per block

    // Number of 4-bit values rearranged together.
    size_t SubBlkLen = 16;
    if (ComputeType == CompInt8 && BlkLen != 16) {
        SubBlkLen = 32;
    }

    const size_t BytesPerSubBlk = SubBlkLen / 2;
    // Half of a sub-block in bytes. This is both the distance between the two source
    // bytes that get interleaved and the number of such source byte pairs per sub-block.
    const size_t HalfSubBlkBytes = BytesPerSubBlk / 2;

    constexpr std::byte LowNibbleMask{0x0F};

    const auto PackOneBlk = [&](ptrdiff_t BlkIdx) {
        // Blocks are stored contiguously, column after column, so the byte offset of a
        // block is simply its linear index times the block size.
        const size_t BlkByteOffset = static_cast<size_t>(BlkIdx) * BytesPerBlk;

        const std::byte* Src = QuantBDataBegin + BlkByteOffset;
        std::byte* Dst = PackedQuantBDataBegin + BlkByteOffset;

        for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen, Src += BytesPerSubBlk, Dst += BytesPerSubBlk) {
            for (size_t i = 0; i < HalfSubBlkBytes; ++i) {
                const std::byte FirstHalfByte = Src[i];
                const std::byte SecondHalfByte = Src[i + HalfSubBlkBytes];

                // low nibbles of both source bytes, then high nibbles of both source bytes
                Dst[2 * i] = (FirstHalfByte & LowNibbleMask) | ((SecondHalfByte & LowNibbleMask) << 4);
                Dst[2 * i + 1] = (FirstHalfByte >> 4) | ((SecondHalfByte >> 4) << 4);
            }
        }
    };

    MlasTrySimpleParallel(ThreadPool, TotalBlkCount, PackOneBlk);
}
} // namespace
//
// General helpers.
//
namespace
{
@ -95,7 +196,16 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
}
}
template <size_t NCols>
} // namespace
//
// CompFp32 kernel implementation.
//
namespace
{
template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompFp32(
size_t BlkLen,
@ -112,11 +222,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
)
{
constexpr size_t BlkBitWidth = 4;
constexpr size_t SubBlkLen = 16;
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
assert(BlkLen % SubBlkLen == 0);
assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
const uint8x8_t LowMask = vdup_n_u8(0x0F);
@ -137,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
// only used if HasZeroPoint == true
for (size_t k = 0; k < CountK; k += BlkLen) {
const size_t k_blk_len = std::min(CountK - k, BlkLen);
@ -147,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
[&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
);
float offset[NCols]; // Includes zero point and float conversion offset of 16.
if (QuantBZeroPointColPtr != nullptr) {
[[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16.
// only used if HasZeroPoint == true
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
@ -157,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
: (zp_packed & std::byte{0x0F});
offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
constexpr float zp = 8.0f;
offset[i] = 16.0f + zp;
});
}
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
@ -187,8 +294,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
});
// dequantize B
// shift left 3 and widen to 16 bits
uint16x8_t bv_u16[NCols][2];
UnrolledLoop<NCols>([&](size_t i) {
@ -217,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
});
// subtract float conversion offset (16) and zero point
UnrolledLoop<NCols>([&](size_t i) {
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});
} else {
const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f);
UnrolledLoop<NCols>([&](size_t i) {
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});
}
// multiply by scale
UnrolledLoop<NCols>([&](size_t i) {
@ -237,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
// increment pointers to next block
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
QuantBScale += 1;
QuantBZeroPointIdx += 1;
if constexpr (HasZeroPoint) {
QuantBZeroPointIdx += 1;
}
}
if constexpr (NCols == 4) {
@ -258,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
}
}
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32(
template <bool HasZeroPoint>
void
SQ4BitGemmM1Kernel_CompFp32_Impl(
size_t BlkLen,
const float* A,
const std::byte* QuantBData,
@ -295,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32(
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
while (nblk >= 0) {
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols>(
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols, HasZeroPoint>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@ -306,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32(
QuantBDataColPtr += NCols * StrideQuantBData;
QuantBScaleColPtr += NCols * StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
if constexpr (HasZeroPoint) {
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
}
@ -319,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32(
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts_BlkBitWidth4_CompFp32<1>(
ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@ -330,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32(
QuantBDataColPtr += StrideQuantBData;
QuantBScaleColPtr += StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
if constexpr (HasZeroPoint) {
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
}
@ -339,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32(
}
}
//
// Entry point for the M=1 CompFp32 kernel.
//
// Chooses the template instantiation of SQ4BitGemmM1Kernel_CompFp32_Impl based on whether
// zero point data is provided (QuantBZeroPoint != nullptr) and forwards all arguments unchanged.
//
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32(
    size_t BlkLen,
    const float* A,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountN,
    size_t CountK,
    size_t BlockStrideQuantB,
    const float* Bias
)
{
    // no zero point data
    if (QuantBZeroPoint == nullptr) {
        SQ4BitGemmM1Kernel_CompFp32_Impl<false>(
            BlkLen,
            A, QuantBData, QuantBScale, QuantBZeroPoint,
            C, CountN, CountK, BlockStrideQuantB, Bias
        );
        return;
    }

    // zero point data provided
    SQ4BitGemmM1Kernel_CompFp32_Impl<true>(
        BlkLen,
        A, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockStrideQuantB, Bias
    );
}
MLAS_FORCEINLINE void
Q4BitBlkDequantBForSgemm_CompFp32(
size_t BlkLen,
@ -353,6 +511,7 @@ Q4BitBlkDequantBForSgemm_CompFp32(
{
auto impl0_reference = [&]() {
constexpr size_t BlkBitWidth = 4;
constexpr size_t SubBlkLen = 16;
float* Dst = FpData;
@ -378,11 +537,11 @@ Q4BitBlkDequantBForSgemm_CompFp32(
: 8;
for (size_t kk = 0; kk < kklen; ++kk) {
const size_t packed_idx = kk % 16;
const size_t packed_idx = kk % SubBlkLen;
const bool is_low_half = packed_idx < 8;
const size_t packed_byte_idx = packed_idx % 8;
const size_t packed_range_offset = (kk / 16) * 8;
const bool is_low_half = packed_idx < (SubBlkLen / 2);
const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2);
const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2);
const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx];
const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4);
@ -415,7 +574,7 @@ Q4BitBlkDequantBForSgemm_CompFp32(
}
//
// CompInt8 kernel implementation and related helpers
// CompInt8 kernel implementation.
//
template <size_t SubBlkLen>
@ -431,8 +590,6 @@ QuantizeBlock(
assert(BlkLen % SubBlkLen == 0);
constexpr size_t VectorCount = SubBlkLen / 4;
//
// Scan block values first to determine scale.
//
@ -443,16 +600,16 @@ QuantizeBlock(
for (k = 0; k < ElementCount; k += SubBlkLen) {
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
float32x4_t a[VectorCount]{};
float32x4_t a[SubBlkLen / 4]{};
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
float32x4_t abs_a[VectorCount];
UnrolledLoop<VectorCount>([&](size_t i) {
float32x4_t abs_a[SubBlkLen / 4];
UnrolledLoop<SubBlkLen / 4>([&](size_t i) {
abs_a[i] = vabsq_f32(a[i]);
});
// find amax of SubBlkLen elements
for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) {
for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) {
for (size_t i = 0; i < interval; ++i) {
abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]);
}
@ -477,19 +634,19 @@ QuantizeBlock(
for (k = 0; k < ElementCount; k += SubBlkLen) {
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
float32x4_t a[VectorCount]{};
float32x4_t a[SubBlkLen / 4]{};
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
UnrolledLoop<VectorCount>([&](size_t i) {
UnrolledLoop<SubBlkLen / 4>([&](size_t i) {
a[i] = vmulq_n_f32(a[i], scale_reciprocal);
});
int32x4_t a_s32[VectorCount];
UnrolledLoop<VectorCount>([&](size_t i) {
int32x4_t a_s32[SubBlkLen / 4];
UnrolledLoop<SubBlkLen / 4>([&](size_t i) {
a_s32[i] = vcvtaq_s32_f32(a[i]);
});
UnrolledLoop<VectorCount>([&](size_t i) {
UnrolledLoop<SubBlkLen / 4>([&](size_t i) {
QuantAData[k + i * 4 + 0] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 0));
QuantAData[k + i * 4 + 1] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 1));
QuantAData[k + i * 4 + 2] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 2));
@ -530,7 +687,7 @@ QuantizeARow_CompInt8(
}
}
template <size_t NCols>
template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompInt8(
size_t BlkLen,
@ -546,20 +703,22 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
const float* BiasPtr
)
{
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
constexpr size_t BlkBitWidth = 4;
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
assert(BlkLen % SubBlkLen == 0);
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32");
const uint8x8_t LowMask = vdup_n_u8(0x0F);
assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
[[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16
[[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32
const std::byte* QuantA = QuantARowPtr;
const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
// only used if HasZeroPoint == true
float32x4_t acc[NCols]{};
@ -572,8 +731,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
float b_scale[NCols];
UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });
int8_t b_zp[NCols];
if (QuantBZeroPointColPtr != nullptr) {
[[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
@ -581,42 +740,73 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
? std::to_integer<int8_t>(zp_packed >> 4)
: std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
b_zp[i] = 8;
});
}
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
// load A row vector
int8x16_t av = vld1q_s8(a_data + k_idx_in_blk);
int8x16_t av[SubBlkLen / 16];
UnrolledLoop<SubBlkLen / 16>([&](size_t i) {
av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16);
});
// load B column vectors
uint8x8_t bv_packed[NCols];
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
UnrolledLoop<NCols>([&](size_t i) {
bv_packed[i] = vld1_u8(
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
);
});
int8x16_t bv[NCols][SubBlkLen / 16];
int8x16_t bv[NCols];
UnrolledLoop<NCols>([&](size_t i) {
const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask));
const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
bv[i] = vcombine_s8(lo, hi);
});
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
if constexpr (SubBlkLen == 16) {
uint8x8_t bv_packed[NCols];
UnrolledLoop<NCols>([&](size_t i) {
bv_packed[i] = vld1_u8(
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
);
});
UnrolledLoop<NCols>([&](size_t i) {
const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8));
const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
bv[i][0] = vcombine_s8(lo, hi);
});
} else {
static_assert(SubBlkLen == 32);
uint8x16_t bv_packed[NCols];
UnrolledLoop<NCols>([&](size_t i) {
bv_packed[i] = vld1q_u8(
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
);
});
UnrolledLoop<NCols>([&](size_t i) {
bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16));
bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4));
});
}
// subtract B zero point
UnrolledLoop<NCols>([&](size_t i) {
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
bv[i] = vsubq_s8(bv[i], zp_v);
});
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
bv[i][j] = vsubq_s8(bv[i][j], zp_v);
});
});
} else {
const int8x16_t zp_v = vdupq_n_s8(8);
UnrolledLoop<NCols>([&](size_t i) {
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
bv[i][j] = vsubq_s8(bv[i][j], zp_v);
});
});
}
// compute quantized dot product
int32x4_t dot[NCols]{};
UnrolledLoop<NCols>([&](size_t i) {
dot[i] = vdotq_s32(dot[i], av, bv[i]);
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]);
});
});
// convert dot product result to float
@ -636,7 +826,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
QuantA += Q8BlkSize(BlkLen);
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
QuantBScale += 1;
QuantBZeroPointIdx += 1;
if constexpr (HasZeroPoint) {
QuantBZeroPointIdx += 1;
}
}
if constexpr (NCols == 4) {
@ -657,9 +849,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
}
}
MLAS_FORCEINLINE
template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
void
SQ4BitGemmM1Kernel_CompInt8(
SQ4BitGemmM1Kernel_CompInt8_Impl(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
@ -673,7 +865,6 @@ SQ4BitGemmM1Kernel_CompInt8(
)
{
constexpr size_t BlkBitWidth = 4;
constexpr size_t NCols = 4;
const std::byte* QuantARowPtr = QuantA;
float* CRowPtr = C;
@ -695,7 +886,7 @@ SQ4BitGemmM1Kernel_CompInt8(
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
while (nblk >= 0) {
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols>(
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen, HasZeroPoint>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@ -706,7 +897,7 @@ SQ4BitGemmM1Kernel_CompInt8(
QuantBDataColPtr += NCols * StrideQuantBData;
QuantBScaleColPtr += NCols * StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
if constexpr (HasZeroPoint) {
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
}
@ -719,7 +910,7 @@ SQ4BitGemmM1Kernel_CompInt8(
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts_BlkBitWidth4_CompInt8<1>(
ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@ -730,7 +921,7 @@ SQ4BitGemmM1Kernel_CompInt8(
QuantBDataColPtr += StrideQuantBData;
QuantBScaleColPtr += StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
if constexpr (HasZeroPoint) {
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
}
@ -739,6 +930,94 @@ SQ4BitGemmM1Kernel_CompInt8(
}
}
//
// Selects the CompInt8 M=1 kernel instantiation according to the runtime
// quantization block length.
//
// The implementation template is parameterized as <NCols, SubBlkLen, HasZeroPoint>.
// This dispatcher always processes 4 columns at a time and chooses the
// sub-block length:
//   - BlkLen == 16      -> SubBlkLen = 16
//   - any other BlkLen  -> SubBlkLen = 32
//
// HasZeroPoint is forwarded unchanged from this function's template argument.
// All runtime arguments are passed through to the implementation as-is.
//
template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountN,
    size_t CountK,
    size_t BlockStrideQuantB,
    const float* Bias
)
{
    // Smallest handled block length: one 16-element sub-block per block.
    if (BlkLen == 16) {
        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>(
            BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
            C, CountN, CountK, BlockStrideQuantB, Bias
        );
        return;
    }

    // Every other block length is processed in 32-element sub-blocks.
    SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>(
        BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockStrideQuantB, Bias
    );
}
//
// Entry point for the CompInt8 M=1 4-bit GEMM kernel.
//
// Converts the runtime presence of a zero-point buffer into the compile-time
// HasZeroPoint template argument and then defers to the block-length
// dispatcher:
//   - QuantBZeroPoint != nullptr -> HasZeroPoint = true
//   - QuantBZeroPoint == nullptr -> HasZeroPoint = false
//
// All arguments, including QuantBZeroPoint itself, are forwarded unchanged.
//
MLAS_FORCEINLINE
void
SQ4BitGemmM1Kernel_CompInt8(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountN,
    size_t CountK,
    size_t BlockStrideQuantB,
    const float* Bias
)
{
    // No zero-point buffer supplied: use the HasZeroPoint == false instantiation.
    if (QuantBZeroPoint == nullptr) {
        SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(
            BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
            C, CountN, CountK, BlockStrideQuantB, Bias
        );
        return;
    }

    // Zero-point buffer supplied: use the HasZeroPoint == true instantiation.
    SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(
        BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockStrideQuantB, Bias
    );
}
} // namespace
//
@ -748,8 +1027,12 @@ SQ4BitGemmM1Kernel_CompInt8(
const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
MLAS_SQNBIT_GEMM_DISPATCH d;
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;

View file

@ -61,10 +61,11 @@ void SQNBITGEMM(benchmark::State& state) {
}
std::unique_ptr<std::byte[]> PackedQuantBData;
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
PackedQuantBDataSize > 0) {
PackedQuantBData = std::make_unique<std::byte[]>(PackedQuantBDataSize);
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get());
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(),
tp.get());
}
MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
@ -87,7 +88,9 @@ void SQNBITGEMM(benchmark::State& state) {
}
}
static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) {
constexpr size_t BlkBitWidth = 4;
b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"});
ArgsProductWithFilter(b,
@ -96,19 +99,17 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
{1, 1024, 2048}, // M
{4096, 11008}, // N
{4096, 11008}, // K
{8}, // Threads
{1, 8}, // Threads
{int64_t{false}, int64_t{true}}, // Symmetric
{int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType
[](const std::vector<int64_t>& args) {
[&](const std::vector<int64_t>& args) {
return MlasIsSQNBitGemmAvailable(
// M, N, K
narrow<size_t>(args[1]), narrow<size_t>(args[2]), narrow<size_t>(args[3]),
// BlkBitWidth, BlkLen
4, narrow<size_t>(args[0]),
BlkBitWidth, narrow<size_t>(args[0]),
// ComputeType
static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(args[6]));
});
}
BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime();

View file

@ -259,10 +259,11 @@ class MlasSQNBitGemmTest : public MlasTestBase {
}
void* PackedQuantBData = nullptr;
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
PackedQuantBDataSize > 0) {
PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize);
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool());
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData,
GetMlasThreadPool());
}
if (ComputeType == CompFp32) {
@ -330,7 +331,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<Blk
bool WithThreadpool, bool Symmetric, bool WithBias) {
size_t tests_registered = 0;
if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType)) {
if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
std::stringstream ss;
ss << (WithThreadpool ? "SingleThread" : "Threaded")
<< "/isSymmetric" << Symmetric