[MLAS AArch64] SQNBitGemm CompInt8 kernel (#18953)

Implement ARM NEON SQNBitGemm kernel that first block quantizes A to int8 and then does int8 multiplication.
This commit is contained in:
Edward Chen 2024-01-12 17:58:08 -08:00 committed by GitHub
parent a756017e9f
commit 150c4cb8fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1667 additions and 559 deletions

View file

@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
set(MLAS_SRC_DIR ${MLAS_ROOT}/lib)
set(MLAS_INC_DIR ${MLAS_ROOT}/inc)
#
# All hardware agnostic source files here
@ -9,6 +11,7 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
# multi-target build
#
onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/mlasi.h
${MLAS_SRC_DIR}/platform.cpp
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_INC_DIR}/mlas_float16.h
${MLAS_INC_DIR}/mlas_gemm_postprocessor.h
${MLAS_INC_DIR}/mlas_q4.h
${MLAS_INC_DIR}/mlas_qnbit.h
${MLAS_INC_DIR}/mlas.h
)
if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4_dq.cpp
@ -46,7 +58,7 @@ endif()
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
@ -143,10 +155,6 @@ function(setup_mlas_source_for_windows)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/arm/sgemmc.cpp
)
# it should be removed after Visual Stuio is upgraded to 17.7
if (MSVC)
add_compile_options("-d2SSAOptimizer-")
endif()
elseif(onnxruntime_target_platform STREQUAL "x64")
file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS
@ -300,8 +308,8 @@ else()
if(APPLE)
get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES)
endif()
list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH)
if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1)
list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH)
if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1)
set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE)
endif()
#If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below
@ -348,6 +356,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
if (NOT APPLE)
set(mlas_platform_srcs
${mlas_platform_srcs}
@ -617,10 +627,12 @@ if(USE_JBLAS)
endif()
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime")
endforeach()
set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")
if (WIN32)
target_compile_options(onnxruntime_mlas PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd6385>" "$<$<COMPILE_LANGUAGE:CXX>:/wd4127>")
if (onnxruntime_ENABLE_STATIC_ANALYSIS)
@ -636,6 +648,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB)
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
# set up source group for MLAS source files
block()
set(source_group_srcs)
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
get_target_property(mlas_target_srcs ${mlas_target} SOURCES)
foreach(mlas_target_src ${mlas_target_srcs})
cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root)
if(in_mlas_root)
list(APPEND source_group_srcs ${mlas_target_src})
endif()
endforeach()
endforeach()
source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs})
endblock()
if (NOT onnxruntime_ORT_MINIMAL_BUILD)
@ -647,7 +674,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
onnxruntime_add_executable(onnxruntime_mlas_q4dq
${MLAS_SRC_DIR}/q4_dq_cli.cpp
)
target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest")
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)

View file

@ -64,6 +64,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
if (!all_constant_) {
return Status::OK();
}
#if defined(MLAS_JBLAS)
auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
@ -101,12 +104,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}
#else // defined(MLAS_JBLAS)
if (input_idx == 1) {
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
#endif // defined(MLAS_JBLAS)
return Status::OK();
}
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
#if defined(MLAS_JBLAS)
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
@ -120,6 +143,15 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
#else // defined(MLAS_JBLAS)
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
#endif // defined(MLAS_JBLAS)
return Status::OK();
}
@ -129,6 +161,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();
#if defined(MLAS_JBLAS)
if (packed_b_.get()) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
@ -158,7 +192,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
@ -166,10 +200,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
const Tensor* b = ctx->Input<Tensor>(1);
#endif // defined(MLAS_JBLAS)
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
@ -181,8 +215,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
Tensor* y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0)
if (y->Shape().Size() == 0) {
return Status::OK();
}
auto* y_data = y->MutableData<float>();
@ -192,36 +227,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
// number of bytes or elements between adjacent matrices
size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
&b_zero_point_matrix_stride_in_bytes);
const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
[](size_t offset) { return offset == 0; });
const size_t b_matrix_size = K * N;
if (has_single_b_matrix && packed_b_) {
for (int64_t accuracy_level = accuracy_level_;
accuracy_level >= static_cast<int64_t>(CompMostAccurate);
--accuracy_level) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
}
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
data[i].QuantBZeroPoint = zero_points_data != nullptr
? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
: nullptr;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);
return Status::OK();
}
}
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool);
return Status::OK();
}
const Tensor* b = ctx->Input<Tensor>(1);
const uint8_t* b_data = b->Data<uint8_t>();
const size_t ldb = helper.Ldb(true);
AllocatorPtr allocator;

View file

@ -23,19 +23,36 @@ Abstract:
#include "mlas.h"
#include "mlas_gemm_postprocessor.h"
/**
* @brief Define compute types of block quantization, in order of decreasing accuracy.
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32, /*!< input fp32, accumulator fp32 */
CompFp16, /*!< input fp16, accumulator fp16 */
CompBf16, /*!< input bf16, accumulator fp32 */
CompInt8, /*!< input int8, accumulator int32 */
// special values that should be the first and last actual values
CompMostAccurate = CompUndef,
CompLeastAccurate = CompInt8,
} MLAS_SQNBIT_COMPUTE_TYPE;
using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these
/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
@ -46,13 +63,26 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
* A must be a float32 matrix
* B must be a quantized and packed n-bit int matrix
*
* Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
*
* Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
* MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
* MlasSQNBitGemmPackQuantBData().
*
* Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
* point to an intermediate workspace buffer.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] Workspace Address of intermediate workspace buffer.
If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
buffer with at least that many bytes. Otherwise, it may be nullptr.
* @param[in] ThreadPool optional thread pool to use
*/
void MLASCALL
@ -63,31 +93,96 @@ MlasSQNBitGemmBatch(
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool = nullptr
);
/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
/**
* @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM
* implementation. If zero, no intermediate workspace is required.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmBatchWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
/**
* @brief Gets the size in bytes of the packed quantized B data.
* If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
* this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
* If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
* MlasSQNBitGemmBatch().
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
*/
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
);
/**
* @brief Define compute types of block quantization
* @brief Packs the quantized B data in a format that the kernel expects.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
* @param[in] ThreadPool optional thread pool to use
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32 = 1, /*!< input fp32, accumulator fp32 */
CompFp16 = 2, /*!< input fp16, accumulator fp16 */
CompBf16 = 3, /*!< input bf16, accumulator fp32 */
CompInt8 = 4 /*!< input int8, accumulator int32 */
} MLAS_SQNBIT_COMPUTE_TYPE;
void MLASCALL
MlasSQNBitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr
);
/**
* @brief Data parameters for NBits GEMM routine
@ -139,7 +234,7 @@ MlasNBitsGemmPackBSize(
* @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor
* one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where
* they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
* (is_asym is false) and Zp(is_asym is true).
* @param thread_pool
*/
@ -186,7 +281,7 @@ MlasNBitsGemmUnPackB(
* @return Workspace size in bytes
*/
size_t MLASCALL
MlasSQNBitsGemmBatchWorkspaceSize(
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,

View file

@ -482,7 +482,6 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
//
// Check if the processor supports ASIMD dot product instructions.
@ -512,6 +511,9 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}
#if defined(__linux__)

View file

@ -11,10 +11,14 @@ Module Name:
Abstract:
This module implements the float/quantized n-bit integer matrix
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch.
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch,
as well as some SQNBitGemm-related query functions.
--*/
#include "sqnbitgemm.h"
#include <cassert>
#ifdef MLAS_JBLAS
#include "jblas_gemm.h"
#endif
@ -22,29 +26,564 @@ Abstract:
namespace
{
// Get quantization variant based on `BlkBitWidth` and `BlkLen`.
// Return -1 if the input values are unsupported.
int32_t
GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
enum SQNBitGemmVariant {
SQNBitGemmVariantInvalid = -1,
// Valid variants
SQNBitGemmVariant_BitWidth4_CompFp32 = 0,
SQNBitGemmVariant_BitWidth4_CompInt8,
// End of valid variants
// Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values.
// Its value is used as an array size.
SQNBitGemmVariantCount,
};
SQNBitGemmVariant
GetSQNBitGemmVariant(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
int32_t type = -1;
if (BlkBitWidth == 4 && BlkLen == 16) {
type = QuantVariant_BitWidth4_BlockSize16;
} else if (BlkBitWidth == 4 && BlkLen == 32) {
type = QuantVariant_BitWidth4_BlockSize32;
} else if (BlkBitWidth == 4 && BlkLen == 64) {
type = QuantVariant_BitWidth4_BlockSize64;
} else if (BlkBitWidth == 4 && BlkLen == 128) {
type = QuantVariant_BitWidth4_BlockSize128;
} else if (BlkBitWidth == 4 && BlkLen == 256) {
type = QuantVariant_BitWidth4_BlockSize256;
MLAS_UNREFERENCED_PARAMETER(N);
MLAS_UNREFERENCED_PARAMETER(K);
if (BlkBitWidth == 4 &&
(BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) {
if (ComputeType == CompFp32 ||
ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32
return SQNBitGemmVariant_BitWidth4_CompFp32;
} else if (ComputeType == CompInt8 && M == 1) {
return SQNBitGemmVariant_BitWidth4_CompInt8;
}
}
return type;
return SQNBitGemmVariantInvalid;
}
} // namespace
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return false;
}
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
Dispatch->QuantizeARow_CompInt8 != nullptr;
}
default: {
return false;
}
}
}
namespace
{
size_t
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
{
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Q8BlkAlignment();
}
default: {
return 1;
}
}
}
size_t
SQNBitGemmPerGemmWorkspaceSize(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
)
{
MLAS_UNREFERENCED_PARAMETER(N);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
// workspace buffer is used for block quantization of A to int8
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
return PerGemmWorkspaceSize;
}
default: {
return 0;
}
}
}
size_t
SQNBitGemmPerGemmWorkspaceStride(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
)
{
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
return MlasDivRoundup(Size, Alignment) * Alignment;
}
} // namespace
size_t MLASCALL
MlasSQNBitGemmBatchWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
if (PerGemmWorkspaceStride == 0) {
return 0;
}
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
return WorkspaceSize + Alignment - 1;
}
namespace
{
void
SQ4BitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkLen,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
)
{
constexpr size_t BlkBitWidth = 4;
assert(BlkLen % 16 == 0);
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t Iterations = N * BlockCountK; // one iteration per block
MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
const size_t n = tid / BlockCountK;
const size_t k_blk = tid % BlockCountK;
const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
const std::byte* QuantBData = QuantBDataBegin + data_offset;
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
//
// Pack 16 4-bit values (8 bytes) at a time like this:
//
// src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
// =>
// dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
//
for (size_t kk = 0; kk < BlkLen; kk += 16) {
for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
const std::byte src0 = QuantBData[byte_pair_idx];
const std::byte src1 = QuantBData[byte_pair_idx + 4];
std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
}
QuantBData += 8;
PackedQuantBData += 8;
}
}
);
}
} // namespace
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
)
{
// Ensure that a general implementation is available on this platform.
// For now, all implementations share the same packed format.
{
// Currently, there are implementations specific to M = 1, so pick a more general M > 1.
constexpr size_t M = 2;
// A CompUndef implementation should be available if any is available.
constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
const bool HasGeneralImplementation =
MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (!HasGeneralImplementation) {
return 0;
}
}
if (BlkBitWidth == 4) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
return PackedQuantBDataSize;
}
return 0;
}
void MLASCALL
MlasSQNBitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool
)
{
if (BlkBitWidth == 4) {
SQ4BitGemmPackQuantBData(
N,
K,
BlkLen,
static_cast<const std::byte*>(QuantBData),
static_cast<std::byte*>(PackedQuantBData),
ThreadPool
);
}
}
namespace
{
MLAS_FORCEINLINE void
AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
{
for (size_t m = 0; m < CountM; m++) {
const float* bias = Bias;
float* sum = C;
for (size_t n = 0; n < CountN; n += 4) {
if (CountN - n < 4) {
for (size_t nn = n; nn < CountN; nn++) {
*sum += *bias;
sum++;
bias++;
}
break;
}
MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
MlasStoreFloat32x4(sum, acc_x);
bias += 4;
sum += 4;
}
C += ldc;
}
}
typedef void(SQNBitGemmFn)(
size_t BlkLen,
size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* PerGemmWorkspace,
size_t RangeStartM,
size_t RangeCountM,
size_t RangeStartN,
size_t RangeCountN
);
void
SQ4BitGemm_CompFp32(
const size_t BlkLen,
const size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
void* const PerGemmWorkspace,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
{
constexpr size_t BlkBitWidth = 4;
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);
const size_t lda = DataParams->lda;
const size_t ldc = DataParams->ldc;
const size_t k_blks = MlasDivRoundup(K, BlkLen);
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
const float* A = DataParams->A + RangeStartM * lda;
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
const float* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}
constexpr size_t StrideN = 32;
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
MlasThreadedBufAlloc(bufsize);
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
//
// Step through each slice of matrix B along the N dimension.
//
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, StrideN);
//
// Step through each slice of matrix A along the M dimension.
//
const float* a_row = A;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
);
#else
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
#endif
if (bias) {
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
}
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
RowsHandled, CountN, ldc
);
}
c_blk += ldc * RowsHandled;
a_row += lda * RowsHandled;
RowsRemaining -= RowsHandled;
}
}
}
void
SQ4BitGemm_CompInt8(
const size_t BlkLen,
const size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
void* const PerGemmWorkspace,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
{
constexpr size_t BlkBitWidth = 4;
const size_t k_blks = MlasDivRoundup(K, BlkLen);
const size_t lda = k_blks * Q8BlkSize(BlkLen);
const size_t ldc = DataParams->ldc;
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
const std::byte* QuantA = static_cast<const std::byte*>(PerGemmWorkspace) + RangeStartM * lda;
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
const std::byte* a_row = QuantA;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}
assert(false && "not implemented for M > 1");
}
typedef void(InitializeWorkspaceFn)(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
);
void
InitializeWorkspace_CompInt8(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
)
{
MLAS_UNREFERENCED_PARAMETER(N);
const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8;
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen);
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];
const float* ARowPtr = data.A;
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
for (size_t m = 0; m < M; ++m) {
QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
ARowPtr += data.lda;
QuantARowPtr += QuantAStride;
}
});
}
struct Operations {
InitializeWorkspaceFn* InitializeWorkspace = nullptr;
SQNBitGemmFn* SQNBitGemm = nullptr;
};
constexpr auto OperationMap = []() {
std::array<Operations, SQNBitGemmVariantCount> ops;
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
return ops;
}();
} // namespace
void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
@ -53,17 +592,43 @@ MlasSQNBitGemmBatch(
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool
)
{
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
assert(Variant != SQNBitGemmVariantInvalid);
//
// Ensure `Workspace` has correct alignment.
//
if (Workspace != nullptr) {
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
Workspace = reinterpret_cast<void*>(
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
);
}
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
InitializeWorkspaceOperation != nullptr) {
InitializeWorkspaceOperation(
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
);
}
const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto Data = &DataParams[gemm_i];
Operation(K, Data, 0, M, 0, N);
const auto* Data = &DataParams[gemm_i];
void* PerGemmWorkspace =
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
}
return;
}
@ -112,7 +677,10 @@ MlasSQNBitGemmBatch(
MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
auto Data = &DataParams[gemm_i];
const auto* Data = &DataParams[gemm_i];
void* PerGemmWorkspace = reinterpret_cast<void*>(
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride
);
const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
@ -123,29 +691,10 @@ MlasSQNBitGemmBatch(
const size_t RangeStartN = ThreadIdN * StrideN;
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
});
}
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
)
{
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
if (QuantVariant == -1) {
return false;
}
if (GetMlasPlatform().SQNBitGemmDispatch == nullptr ||
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) {
return false;
}
return true;
}
size_t MLASCALL
MlasNBitsGemmPackBSize(
size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
@ -224,7 +773,7 @@ MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, s
}
size_t MLASCALL
MlasSQNBitsGemmBatchWorkspaceSize(
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,

View file

@ -10,98 +10,23 @@ Module Name:
Abstract:
This module includes:
This module includes kernel function prototypes and helper functions for
implementing SQNBitGemm.
- Declaration of the set of template functions used to implement a kernel
for a matrix/matrix multiplication, A*B, where A is a float matrix and B is
a n-bit quantized integer matrix (QNBitGemm).
- A shared kernel driver function template, MlasSQNBitGemmOperation.
- Kernel dispatch structure.
The B matrix is block quantized, which means that its values are grouped
into blocks which each have one scale and optional zero point. Each
quantized value in B is n-bits wide.
SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float
matrix and B is a n-bit quantized integer matrix. B is block quantized,
meaning values of B are divided into blocks and each block has its own
scale and optional zero point.
--*/
#pragma once
#include <cassert>
#include "mlas_qnbit.h"
#include "mlasi.h"
//
// Kernel implementation template declarations
//
/**
* @brief Multiply float matrix A with quantized n-bit integer matrix B.
* B is block quantized and column major.
* This kernel handles the special case where M, the number of rows of A and C, is 1.
*
* @tparam BlkBitWidth Bit width of each value in a block.
* @tparam BlkLen Number of values in a block.
* @tparam KernelType Hardware-specific kernel type.
*
* @param A Supplies the A matrix.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountN Number of columns of B and C.
* @param CountK Number of columns of A and rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
* @param Bias Bias vector of length N.
*/
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
MLAS_FORCEINLINE void
MlasSQNBitGemmM1Kernel(
const float* A,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
);
/**
* @brief Dequantize B into the format expected by the Sgemm kernel.
* B is block quantized and column major.
* This is equivalent to dequantizing B and then running
* MlasSgemmCopyPackB.
*
* @tparam BlkBitWidth Bit width of each value in a block.
* @tparam BlkLen Number of values in a block.
* @tparam KernelType Hardware-specific kernel type.
*
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param CountN Number of columns of B.
* @param CountK Number of rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
*/
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
MLAS_FORCEINLINE void
MlasQNBitBlkDequantBForSgemm(
float* FpData,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB
);
//
// MlasQNBitGemmOperation and helpers
//
constexpr MLAS_FORCEINLINE size_t
MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
{
@ -119,169 +44,174 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
}
}
MLAS_FORCEINLINE void
MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
{
for (size_t m = 0; m < CountM; m++) {
const float* bias = Bias;
float* sum = C;
for (size_t n = 0; n < CountN; n += 4) {
if (CountN - n < 4) {
for (size_t nn = n; nn < CountN; nn++) {
*sum += *bias;
sum++;
bias++;
}
break;
}
//
// Quantized int8 block helpers.
//
MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
MlasStoreFloat32x4(sum, acc_x);
bias += 4;
sum += 4;
}
C += ldc;
}
MLAS_FORCEINLINE
const float&
Q8BlkScale(const std::byte* BlkPtr)
{
return *reinterpret_cast<const float*>(BlkPtr);
}
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
MLAS_FORCEINLINE void MLASCALL
MlasSQNBitGemmOperation(
const size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
MLAS_FORCEINLINE
float&
Q8BlkScale(std::byte* BlkPtr)
{
const size_t lda = DataParams->lda;
const size_t ldc = DataParams->ldc;
return *reinterpret_cast<float*>(BlkPtr);
}
const size_t k_blks = MlasDivRoundup(K, BlkLen);
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
MLAS_FORCEINLINE
const int8_t*
Q8BlkData(const std::byte* BlkPtr)
{
return reinterpret_cast<const int8_t*>(BlkPtr + sizeof(float));
}
const float* A = DataParams->A + RangeStartM * lda;
MLAS_FORCEINLINE
int8_t*
Q8BlkData(std::byte* BlkPtr)
{
return reinterpret_cast<int8_t*>(BlkPtr + sizeof(float));
}
const uint8_t* QuantBData = static_cast<const uint8_t*>(DataParams->QuantBData) + RangeStartN * ldb;
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
const uint8_t* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const uint8_t*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
MLAS_FORCEINLINE
constexpr size_t
Q8BlkSize(size_t BlkLen)
{
const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t);
// Currently, the strictest alignment requirement of a block is for a float.
// Ensure contiguous blocks are suitably aligned.
assert(BlkSize % alignof(float) == 0);
return BlkSize;
}
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
const float* a_row = A;
const uint8_t* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const uint8_t* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, KernelType>(
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}
constexpr size_t StrideN = 32;
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
MlasThreadedBufAlloc(bufsize);
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
//
// Step through each slice of matrix B along the N dimension.
//
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, StrideN);
//
// Step through each slice of matrix A along the M dimension.
//
const float* a_row = A;
const uint8_t* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const uint8_t* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, KernelType>(
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
);
#else
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
#endif
if (bias) {
MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
}
if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
RowsHandled, CountN, ldc
);
}
c_blk += ldc * RowsHandled;
a_row += lda * RowsHandled;
RowsRemaining -= RowsHandled;
}
}
MLAS_FORCEINLINE
constexpr size_t
Q8BlkAlignment()
{
return alignof(float);
}
//
// Kernel dispatch structure.
//
typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)(
size_t K,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
size_t RangeStartM,
size_t RangeCountM,
size_t RangeStartN,
size_t RangeCountN
);
enum QuantVariant {
QuantVariant_BitWidth4_BlockSize16,
QuantVariant_BitWidth4_BlockSize32,
QuantVariant_BitWidth4_BlockSize64,
QuantVariant_BitWidth4_BlockSize128,
QuantVariant_BitWidth4_BlockSize256,
QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values.
// Its value is used as an array size.
};
struct MLAS_SQNBIT_GEMM_DISPATCH {
MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = {
// Initialized to nullptrs. Overwrite in hardware-specific kernel implementation.
};
//
// CompFp32 kernel function prototypes.
//
/**
* @brief Multiply float matrix A with quantized 4-bit integer matrix B.
* B is block quantized and column major.
* This kernel handles the special case where M, the number of rows of A and C, is 1.
*
* @param BlkLen Number of values in a block.
* @param A Supplies the A matrix.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountN Number of columns of B and C.
* @param CountK Number of columns of A and rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
* @param Bias Bias vector of length N.
*/
typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)(
size_t BlkLen,
const float* A,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
);
SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr;
/**
* @brief Dequantize B into the format expected by the Sgemm kernel.
* B is a quantized 4-bit integer matrix that is block quantized and column major.
* This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
*
* @param BlkLen Number of values in a block.
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param CountN Number of columns of B.
* @param CountK Number of rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
*/
typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)(
size_t BlkLen,
float* FpData,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB
);
Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
//
// CompInt8 kernel function prototypes.
//
/**
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
* A and B are block quantized and B is column major.
* This kernel handles the special case where M, the number of rows of A and C, is 1.
*
* @param BlkLen Number of values in a block.
* @param QuantA Supplies the quantized A matrix.
Binary data containing block quantized int8 data and scale values.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountN Number of columns of B and C.
* @param CountK Number of columns of A and rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
* @param Bias Bias vector of length N.
*/
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
);
SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
/**
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
*
* @param BlkLen Number of values in a block.
* @param A Supplies the A matrix.
* @param CountK Number of columns of A.
* @param[out] QuantA Supplies the output quantized A matrix.
* Binary data containing block quantized int8 data and scale values.
*/
typedef void(QuantizeARow_CompInt8_Fn)(
size_t BlkLen,
const float* A,
size_t CountK,
std::byte* QuantA
);
QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr;
};

View file

@ -23,12 +23,6 @@ Abstract:
#include <cassert>
#include <utility>
//
// Hardware-specific kernel type.
//
struct MLAS_SQNBIT_GEMM_KERNEL_NEON {
};
namespace
{
@ -70,7 +64,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
template <size_t Capacity>
MLAS_FORCEINLINE void
LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4])
LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
{
static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4.");
@ -101,13 +95,14 @@ LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4])
}
}
template <size_t BlkBitWidth, size_t BlkLen, size_t NCols>
template <size_t NCols>
MLAS_FORCEINLINE void
ComputeDotProducts(
ComputeDotProducts_BlkBitWidth4_CompFp32(
size_t BlkLen,
const float* ARowPtr,
const uint8_t* QuantBDataColPtr,
const std::byte* QuantBDataColPtr,
const float* QuantBScaleColPtr,
const uint8_t* QuantBZeroPointColPtr,
const std::byte* QuantBZeroPointColPtr,
float* SumPtr,
size_t CountK,
size_t StrideQuantBData,
@ -116,8 +111,13 @@ ComputeDotProducts(
const float* BiasPtr
)
{
constexpr size_t BlkBitWidth = 4;
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
assert(BlkLen % SubBlkLen == 0);
const uint8x8_t LowMask = vdup_n_u8(0x0F);
// Manual conversion to float takes place in two steps:
@ -135,7 +135,7 @@ ComputeDotProducts(
float32x4_t acc[NCols]{};
const uint8_t* QuantBData = QuantBDataColPtr;
const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
@ -150,10 +150,12 @@ ComputeDotProducts(
float offset[NCols]; // Includes zero point and float conversion offset of 16.
if (QuantBZeroPointColPtr != nullptr) {
UnrolledLoop<NCols>([&](size_t i) {
const uint8_t zp_packed =
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F);
offset[i] = 16.0f + zp;
const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
? (zp_packed >> 4)
: (zp_packed & std::byte{0x0F});
offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
@ -162,33 +164,27 @@ ComputeDotProducts(
});
}
constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
// load A row vector elements
// load `SubBlkLen` elements from A, padded with 0's if there aren't enough
const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen);
float32x4_t av[4]{};
LoadData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
LoadFloatData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
// load B column vectors
uint8x8_t bv_packed[NCols];
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
UnrolledLoop<NCols>([&](size_t i) {
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset);
});
uint8x8_t bv_u8_unzipped[NCols][2];
UnrolledLoop<NCols>([&](size_t i) {
bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask);
bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask);
bv_packed[i] = vld1_u8(
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
);
});
uint8x8_t bv_u8[NCols][2];
UnrolledLoop<NCols>([&](size_t i) {
bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
});
// dequantize B
@ -262,19 +258,13 @@ ComputeDotProducts(
}
}
} // namespace
//
// MlasSQNBitGemmKernel and helpers.
//
template <size_t BlkBitWidth, size_t BlkLen>
MLAS_FORCEINLINE void
MlasSQNBitGemmM1KernelNeon(
SQ4BitGemmM1Kernel_CompFp32(
size_t BlkLen,
const float* A,
const uint8_t* QuantBData,
const std::byte* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
@ -282,6 +272,7 @@ MlasSQNBitGemmM1KernelNeon(
const float* Bias
)
{
constexpr size_t BlkBitWidth = 4;
constexpr size_t NCols = 4;
const float* ARowPtr = A;
@ -295,16 +286,17 @@ MlasSQNBitGemmM1KernelNeon(
const float* BiasPtr = Bias;
const uint8_t* QuantBDataColPtr = QuantBData;
const std::byte* QuantBDataColPtr = QuantBData;
const float* QuantBScaleColPtr = QuantBScale;
const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint;
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
float* SumPtr = CRowPtr;
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
while (nblk >= 0) {
ComputeDotProducts<BlkBitWidth, BlkLen, NCols>(
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
BiasPtr
@ -327,7 +319,8 @@ MlasSQNBitGemmM1KernelNeon(
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts<BlkBitWidth, BlkLen, 1>(
ComputeDotProducts_BlkBitWidth4_CompFp32<1>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
BiasPtr
@ -346,59 +339,26 @@ MlasSQNBitGemmM1KernelNeon(
}
}
#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \
template <> \
MLAS_FORCEINLINE void \
MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
const float* A, \
const uint8_t* QuantBData, \
const float* QuantBScale, \
const uint8_t* QuantBZeroPoint, \
float* C, \
size_t CountN, \
size_t CountK, \
size_t BlockStrideQuantB, \
const float* Bias \
) \
{ \
return MlasSQNBitGemmM1KernelNeon<BlkBitWidth, BlkLen>( \
A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \
BlockStrideQuantB, Bias \
); \
}
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16)
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32)
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64)
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128)
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256)
#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL
//
// MlasQNBitBlkDequantBForSgemm and helpers.
//
template <size_t BlkBitWidth, size_t BlkLen>
MLAS_FORCEINLINE void
MlasQNBitBlkDequantBForSgemmNeon(
Q4BitBlkDequantBForSgemm_CompFp32(
size_t BlkLen,
float* FpData,
const uint8_t* QuantBData,
const std::byte* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const std::byte* QuantBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB
)
{
auto impl0_reference = [&]() {
static_assert(BlkBitWidth == 4);
constexpr size_t BlkBitWidth = 4;
float* Dst = FpData;
const uint8_t* QuantBDataCol = QuantBData;
const std::byte* QuantBDataCol = QuantBData;
const float* QuantBScaleCol = QuantBScale;
const uint8_t* QuantBZeroPointCol = QuantBZeroPoint;
const std::byte* QuantBZeroPointCol = QuantBZeroPoint;
for (size_t n = 0; n < CountN; n += 16) {
const size_t nnlen = std::min(CountN - n, size_t{16});
@ -407,20 +367,26 @@ MlasQNBitBlkDequantBForSgemmNeon(
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
const size_t kklen = std::min(CountK - k, BlkLen);
const uint8_t* b_data =
const std::byte* b_data =
QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const float b_s = QuantBScaleCol[k_blk_idx];
const uint8_t b_z =
(QuantBZeroPointCol != nullptr)
? ((k_blk_idx & 1) == 1)
? QuantBZeroPointCol[k_blk_idx / 2] >> 4
: QuantBZeroPointCol[k_blk_idx / 2] & 0x0F
? std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] >> 4)
: std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F})
: 8;
for (size_t kk = 0; kk < kklen; ++kk) {
const uint8_t b_packed = b_data[kk / 2];
const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F;
const float b_value = (b_byte - b_z) * b_s;
const size_t packed_idx = kk % 16;
const bool is_low_half = packed_idx < 8;
const size_t packed_byte_idx = packed_idx % 8;
const size_t packed_range_offset = (kk / 16) * 8;
const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx];
const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4);
const float b_value = (std::to_integer<int8_t>(b_byte) - b_z) * b_s;
Dst[(k + kk) * 16 + nn] = b_value;
}
@ -448,31 +414,332 @@ MlasQNBitBlkDequantBForSgemmNeon(
impl0_reference();
}
#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \
template <> \
MLAS_FORCEINLINE void \
MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
float* FpData, \
const uint8_t* QuantBData, \
const float* QuantBScale, \
const uint8_t* QuantBZeroPoint, \
size_t CountN, \
size_t CountK, \
size_t BlockStrideQuantB \
) \
{ \
MlasQNBitBlkDequantBForSgemmNeon<BlkBitWidth, BlkLen>( \
FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \
); \
//
// CompInt8 kernel implementation and related helpers
//
template <size_t SubBlkLen>
MLAS_FORCEINLINE void
QuantizeBlock(
size_t BlkLen,
const float* A,
size_t ElementCount,
std::byte* QuantA
)
{
static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0);
assert(BlkLen % SubBlkLen == 0);
constexpr size_t VectorCount = SubBlkLen / 4;
//
// Scan block values first to determine scale.
//
float amax = 0.0f; // max of absolute values of A block
size_t k;
for (k = 0; k < ElementCount; k += SubBlkLen) {
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
float32x4_t a[VectorCount]{};
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
float32x4_t abs_a[VectorCount];
UnrolledLoop<VectorCount>([&](size_t i) {
abs_a[i] = vabsq_f32(a[i]);
});
// find amax of SubBlkLen elements
for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) {
for (size_t i = 0; i < interval; ++i) {
abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]);
}
}
// update existing amax
amax = std::max(amax, vmaxvq_f32(abs_a[0]));
}
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256)
constexpr float range_max = (1 << 7) - 1;
const float scale = amax / range_max;
const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM
Q8BlkScale(QuantA) = scale;
//
// Compute quantized block values.
//
int8_t* QuantAData = Q8BlkData(QuantA);
for (k = 0; k < ElementCount; k += SubBlkLen) {
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
float32x4_t a[VectorCount]{};
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
UnrolledLoop<VectorCount>([&](size_t i) {
a[i] = vmulq_n_f32(a[i], scale_reciprocal);
});
int32x4_t a_s32[VectorCount];
UnrolledLoop<VectorCount>([&](size_t i) {
a_s32[i] = vcvtaq_s32_f32(a[i]);
});
UnrolledLoop<VectorCount>([&](size_t i) {
QuantAData[k + i * 4 + 0] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 0));
QuantAData[k + i * 4 + 1] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 1));
QuantAData[k + i * 4 + 2] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 2));
QuantAData[k + i * 4 + 3] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 3));
});
}
//
// Zero out any remaining sub-block elements.
//
for (; k < BlkLen; k += SubBlkLen) {
const int8x16_t Zeros = vdupq_n_s8(0);
UnrolledLoop<SubBlkLen / 16>([&](size_t i) {
vst1q_s8(QuantAData + k + i * 16, Zeros);
});
}
}
void MLASCALL
QuantizeARow_CompInt8(
size_t BlkLen,
const float* A,
size_t CountK,
std::byte* QuantA
)
{
const float* ADataBlkPtr = A;
std::byte* QuantABlkPtr = QuantA;
for (size_t k = 0; k < CountK; k += BlkLen) {
const size_t k_blk_len = std::min(CountK - k, BlkLen);
QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr);
ADataBlkPtr += BlkLen;
QuantABlkPtr += Q8BlkSize(BlkLen);
}
}
template <size_t NCols>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompInt8(
size_t BlkLen,
const std::byte* QuantARowPtr,
const std::byte* QuantBDataColPtr,
const float* QuantBScaleColPtr,
const std::byte* QuantBZeroPointColPtr,
float* SumPtr,
size_t CountK,
size_t StrideQuantBData,
size_t StrideQuantBScale,
size_t StrideQuantBZeroPoint,
const float* BiasPtr
)
{
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
constexpr size_t BlkBitWidth = 4;
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
assert(BlkLen % SubBlkLen == 0);
const uint8x8_t LowMask = vdup_n_u8(0x0F);
const std::byte* QuantA = QuantARowPtr;
const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
float32x4_t acc[NCols]{};
for (size_t k = 0; k < CountK; k += BlkLen) {
const size_t k_blk_len = std::min(CountK - k, BlkLen);
const float a_scale = Q8BlkScale(QuantA);
const int8_t* a_data = Q8BlkData(QuantA);
float b_scale[NCols];
UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });
int8_t b_zp[NCols];
if (QuantBZeroPointColPtr != nullptr) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
b_zp[i] = ((QuantBZeroPointIdx & 1) == 1)
? std::to_integer<int8_t>(zp_packed >> 4)
: std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
b_zp[i] = 8;
});
}
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
// load A row vector
int8x16_t av = vld1q_s8(a_data + k_idx_in_blk);
// load B column vectors
uint8x8_t bv_packed[NCols];
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
UnrolledLoop<NCols>([&](size_t i) {
bv_packed[i] = vld1_u8(
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
);
});
int8x16_t bv[NCols];
UnrolledLoop<NCols>([&](size_t i) {
const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask));
const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
bv[i] = vcombine_s8(lo, hi);
});
// subtract B zero point
UnrolledLoop<NCols>([&](size_t i) {
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
bv[i] = vsubq_s8(bv[i], zp_v);
});
// compute quantized dot product
int32x4_t dot[NCols]{};
UnrolledLoop<NCols>([&](size_t i) {
dot[i] = vdotq_s32(dot[i], av, bv[i]);
});
// convert dot product result to float
float32x4_t dot_f32[NCols];
UnrolledLoop<NCols>([&](size_t i) {
dot_f32[i] = vcvtq_f32_s32(dot[i]);
});
// multiply dot product result by scale and update accumulator
UnrolledLoop<NCols>([&](size_t i) {
const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]);
acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v);
});
}
// increment pointers to next block
QuantA += Q8BlkSize(BlkLen);
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
QuantBScale += 1;
QuantBZeroPointIdx += 1;
}
if constexpr (NCols == 4) {
float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
if (BiasPtr != nullptr) {
sum = vaddq_f32(sum, vld1q_f32(BiasPtr));
}
vst1q_f32(SumPtr, sum);
} else {
for (size_t i = 0; i < NCols; ++i) {
SumPtr[i] = vaddvq_f32(acc[i]);
if (BiasPtr != nullptr) {
SumPtr[i] += BiasPtr[i];
}
}
}
}
MLAS_FORCEINLINE
void
SQ4BitGemmM1Kernel_CompInt8(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
)
{
constexpr size_t BlkBitWidth = 4;
constexpr size_t NCols = 4;
const std::byte* QuantARowPtr = QuantA;
float* CRowPtr = C;
const size_t BlockCountK = BlockStrideQuantB;
const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t StrideQuantBScale = BlockCountK;
const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
const float* BiasPtr = Bias;
const std::byte* QuantBDataColPtr = QuantBData;
const float* QuantBScaleColPtr = QuantBScale;
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
float* SumPtr = CRowPtr;
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
while (nblk >= 0) {
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
BiasPtr
);
// move to next `NCols` columns
QuantBDataColPtr += NCols * StrideQuantBData;
QuantBScaleColPtr += NCols * StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
}
BiasPtr += BiasPtr != nullptr ? NCols : 0;
SumPtr += NCols;
nblk -= NCols;
}
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts_BlkBitWidth4_CompInt8<1>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
BiasPtr
);
// move to next column
QuantBDataColPtr += StrideQuantBData;
QuantBScaleColPtr += StrideQuantBScale;
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
}
BiasPtr += BiasPtr != nullptr ? 1 : 0;
SumPtr += 1;
}
}
} // namespace
//
// Kernel dispatch structure definition.
@ -480,10 +747,11 @@ SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256)
const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
MLAS_SQNBIT_GEMM_DISPATCH d;
d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
return d;
}();

View file

@ -109,12 +109,19 @@ void Q8Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) {
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
b->ArgNames(q4gemm_bench_arg_names);
ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}, {8}});
b->ArgsProduct({{1, 1024, 2048}, {4096}, {4096}, {8}});
}
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
[[maybe_unused]] static const bool benchmarks_registered = []() {
const bool is_q4gemm_supported = MlasQ4GemmPackBSize(BlkQ4Sym, 1, 1) > 0;
if (is_q4gemm_supported) {
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
return true;
}
return false;
}();

View file

@ -224,8 +224,7 @@ BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
static void General_Conv2d(benchmark::internal::Benchmark* b) {
b->ArgNames(ArgNamesForConv(2));
ArgsProduct(
b,
b->ArgsProduct(
{{2}, // Rank,
{1}, // N
{1, 2}, // Groups

View file

@ -103,14 +103,14 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo
static void GemmSizeWithOne(benchmark::internal::Benchmark* b) {
b->ArgNames(sgemm_bench_arg_names);
ArgsProduct(b, {{1}, {63, 255, 1023}, {63, 255, 1023}});
ArgsProduct(b, {{63, 255, 1023}, {1}, {63, 255, 1023}});
ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {1}});
b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}});
b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}});
b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}});
}
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
b->ArgNames(sgemm_bench_arg_names);
ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}});
b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}});
}
BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime();
@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct
static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) {
b->ArgNames(sgemm_bench_arg_names);
ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}});
b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}});
}
BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime();

View file

@ -4,33 +4,36 @@
#include "mlas_q4.h"
#include "mlas_qnbit.h"
#include <memory>
#include <stdexcept>
#include <vector>
#include "benchmark/benchmark.h"
#include "bench_util.h"
#include "core/util/thread_utils.h"
#include "core/common/narrow.h"
template <size_t BlkBitWidth, size_t BlkLen, bool Symmetric>
using onnxruntime::narrow;
template <size_t BlkBitWidth>
void SQNBITGEMM(benchmark::State& state) {
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!");
const size_t M = static_cast<size_t>(state.range(0));
const size_t N = static_cast<size_t>(state.range(1));
const size_t K = static_cast<size_t>(state.range(2));
const size_t threads = static_cast<size_t>(state.range(3));
const auto BlkLen = narrow<size_t>(state.range(0));
const auto M = narrow<size_t>(state.range(1));
const auto N = narrow<size_t>(state.range(2));
const auto K = narrow<size_t>(state.range(3));
const auto Threads = narrow<size_t>(state.range(4));
const auto Symmetric = narrow<bool>(state.range(5));
const auto ComputeType = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(state.range(6));
size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
MlasBlockwiseQuantizedBufferSizes(
BlkBitWidth, BlkLen, /* columnwise */ true,
BlkBitWidth, static_cast<int>(BlkLen), /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
OrtThreadPoolParams tpo;
tpo.thread_pool_size = static_cast<int>(threads);
tpo.thread_pool_size = static_cast<int>(Threads);
tpo.auto_set_affinity = true;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> tp(
@ -47,14 +50,29 @@ void SQNBITGEMM(benchmark::State& state) {
MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData.data(), QuantBScale.data(),
Symmetric ? nullptr : QuantBZeroPoint.data(),
B.data(), BlkLen, /* columnwise */ true,
B.data(), static_cast<int>(BlkLen), /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N), static_cast<int>(N),
tp.get());
std::unique_ptr<std::byte[]> Workspace;
if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType);
WorkspaceSize > 0) {
Workspace = std::make_unique<std::byte[]>(WorkspaceSize);
}
std::unique_ptr<std::byte[]> PackedQuantBData;
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
PackedQuantBDataSize > 0) {
PackedQuantBData = std::make_unique<std::byte[]>(PackedQuantBDataSize);
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get());
}
MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
params.A = A.data();
params.lda = K;
params.QuantBData = QuantBData.data();
params.QuantBData = PackedQuantBData != nullptr
? static_cast<const void*>(PackedQuantBData.get())
: static_cast<const void*>(QuantBData.data());
params.QuantBScale = QuantBScale.data();
params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data();
params.Bias = nullptr;
@ -62,30 +80,41 @@ void SQNBITGEMM(benchmark::State& state) {
params.ldc = N;
// warm up run
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace.get(), tp.get());
for (auto _ : state) {
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace.get(), tp.get());
}
}
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
b->ArgNames({"M", "N", "K", "Threads"});
ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}});
static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"});
ArgsProductWithFilter(b,
{{16, 32, 64, 128, 256}, // BlkLen
{1, 1024, 2048}, // M
{4096, 11008}, // N
{4096, 11008}, // K
{8}, // Threads
{int64_t{false}, int64_t{true}}, // Symmetric
{int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType
[](const std::vector<int64_t>& args) {
return MlasIsSQNBitGemmAvailable(
// M, N, K
narrow<size_t>(args[1]), narrow<size_t>(args[2]), narrow<size_t>(args[3]),
// BlkBitWidth, BlkLen
4, narrow<size_t>(args[0]),
// ComputeType
static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(args[6]));
});
}
BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
#if defined(MLAS_JBLAS)
#ifdef MLAS_JBLAS
void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE cmp_type) {
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
@ -130,6 +159,11 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQ
}
}
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
b->ArgNames({"M", "N", "K", "Threads"});
b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}});
}
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymInt8, 32, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymInt8, 128, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymInt8, -1, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
@ -137,4 +171,5 @@ BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymFp32, 32, false, CompFp32)->Apply(GemmSi
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymFp32, 128, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymFp32, -1, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32AsymFp32, 32, true, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
#endif
#endif // defined(MLAS_JBLAS)

View file

@ -23,10 +23,9 @@ std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_val
return RandomVectorUniform(static_cast<size_t>(sz), min_value, max_value);
}
// The Benchmark used here do not contains this as in newer version.
// Use the code from newer version.
void ArgsProduct(benchmark::internal::Benchmark* bench,
const std::vector<std::vector<int64_t>>& arglists) {
void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
const std::vector<std::vector<int64_t>>& arglists,
std::function<bool(const std::vector<int64_t>& args)> include_filter) {
std::vector<std::size_t> indices(arglists.size(), 0);
const std::size_t total = std::accumulate(
std::begin(arglists), std::end(arglists), std::size_t{1},
@ -39,7 +38,9 @@ void ArgsProduct(benchmark::internal::Benchmark* bench,
for (std::size_t arg = 0; arg < arglists.size(); arg++) {
args.push_back(arglists[arg][indices[arg]]);
}
bench->Args(args);
if (include_filter(args)) {
bench->Args(args);
}
args.clear();
std::size_t arg = 0;

View file

@ -5,10 +5,14 @@
#include <benchmark/benchmark.h>
#include <functional>
#include <random>
void ArgsProduct(benchmark::internal::Benchmark* bench,
const std::vector<std::vector<int64_t>>& arglists);
// Specifies benchmark arguments from the cartesian product of `arglists`, like Benchmark::ArgsProduct().
// `include_filter` is called to determine whether a given set of arguments should be included.
void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
const std::vector<std::vector<int64_t>>& arglists,
std::function<bool(const std::vector<int64_t>& args)> include_filter);
template <typename ElementType>
std::vector<ElementType> RandomVectorUniform(

View file

@ -18,6 +18,17 @@ Abstract:
#include "mlas_q4.h"
#include "mlas_qnbit.h"
static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) {
switch (ComputeType) {
case CompFp32:
return "Fp32";
case CompInt8:
return "Int8";
default:
return "unknown";
}
}
/**
* @brief Test class for n-bit int block quantized GEMM
* Note: only 2-D matmul supported for now
@ -26,12 +37,16 @@ template <size_t BlkBitWidth, size_t BlkLen>
class MlasSQNBitGemmTest : public MlasTestBase {
private:
MatrixGuardBuffer<float> BufferA;
MatrixGuardBuffer<int8_t> BufferQuantAData;
MatrixGuardBuffer<float> BufferQuantAScale;
MatrixGuardBuffer<float> BufferB;
MatrixGuardBuffer<uint8_t> BufferQuantBData;
MatrixGuardBuffer<std::byte> BufferPackedQuantBData;
MatrixGuardBuffer<uint8_t> BufferQuantBZeroPoint;
MatrixGuardBuffer<float> BufferQuantBScale;
MatrixGuardBuffer<float> BufferDequantizedB;
MatrixGuardBuffer<float> BufferBias;
MatrixGuardBuffer<std::byte> BufferWorkspace;
MatrixGuardBuffer<float> BufferC;
MatrixGuardBuffer<float> BufferCReference;
@ -40,12 +55,15 @@ class MlasSQNBitGemmTest : public MlasTestBase {
size_t K,
const float* A,
size_t lda,
const uint8_t* QuantBData,
const void* QuantBData,
const void* PackedQuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const void* QuantBZeroPoint,
const float* Bias,
float* C,
size_t ldc,
void* Workspace,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
MLAS_THREADPOOL* Threadpool) {
MLAS_SQNBIT_GEMM_DATA_PARAMS params;
params.A = A;
@ -53,23 +71,106 @@ class MlasSQNBitGemmTest : public MlasTestBase {
params.Bias = Bias;
params.C = C;
params.ldc = ldc;
params.QuantBData = QuantBData;
params.QuantBData = PackedQuantBData != nullptr ? PackedQuantBData : QuantBData;
params.QuantBScale = QuantBScale;
params.QuantBZeroPoint = QuantBZeroPoint;
params.PostProcessor = nullptr;
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, Threadpool);
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace, Threadpool);
}
void CallReferenceGemm(size_t M,
size_t N,
size_t K,
const float* A,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const float* Bias,
float* C) {
void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) {
const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
const size_t lda = K;
for (size_t m = 0; m < M; ++m) {
for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) {
const size_t local_blk_len = std::min(K - k, BlkLen);
float blk_a[BlkLen]{};
std::copy_n(A + m * lda + k, local_blk_len, blk_a);
float amax = 0.0f; // max of absolute values of A block
for (size_t kk = 0; kk < local_blk_len; ++kk) {
float a = blk_a[kk];
amax = std::max(amax, fabsf(a));
}
constexpr float range_max = (1 << 7) - 1;
const float scale = amax / range_max;
const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
QuantAScale[m * BlockCountK + k_blk] = scale;
for (size_t kk = 0; kk < BlkLen; ++kk) {
const float q = roundf(blk_a[kk] * scale_reciprocal);
QuantAData[m * BlockCountK * BlkLen + k + kk] =
static_cast<int8_t>(
std::clamp(q,
static_cast<float>(std::numeric_limits<int8_t>::min()),
static_cast<float>(std::numeric_limits<int8_t>::max())));
}
}
}
}
void CallReferenceGemm_CompInt8(size_t M,
size_t N,
size_t K,
const float* A,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const float* Bias,
float* C) {
const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
int8_t* QuantAData = BufferQuantAData.GetBuffer(M * BlockCountK * BlkLen);
float* QuantAScale = BufferQuantAScale.GetBuffer(M * BlockCountK);
QuantizeA(M, K, A, QuantAData, QuantAScale);
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
float sum = Bias == nullptr ? 0.0f : Bias[n];
for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) {
const size_t k_blk_len = std::min(K - k, BlkLen);
const float a_scale = QuantAScale[m * BlockCountK + k_blk];
const float b_scale = QuantBScale[n * BlockCountK + k_blk];
static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B");
uint8_t b_zp = 8;
if (QuantBZeroPoint != nullptr) {
const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2];
b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F);
}
int32_t qsum = 0;
for (size_t kk = 0; kk < k_blk_len; ++kk) {
const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk];
const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2];
const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp;
qsum += qa * qb;
}
sum += static_cast<float>(qsum) * a_scale * b_scale;
}
C[m * N + n] = sum;
}
}
}
void CallReferenceGemm_CompFp32(size_t M,
size_t N,
size_t K,
const float* A,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
const float* Bias,
float* C) {
float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N);
MlasDequantizeBlockwise<float, BlkBitWidth>(
DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true,
@ -95,6 +196,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
public:
void Test(size_t M, size_t N, size_t K,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
bool WithBias, bool Symmetric, bool WithThreadpool) {
MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr;
@ -126,7 +228,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
float* C = BufferC.GetBuffer(N * M, true);
float* CReference = BufferCReference.GetBuffer(N * M, true);
// pack B
// quantize B
uint8_t* QuantBData = nullptr;
float* QuantBScale = nullptr;
uint8_t* QuantBZeroPoint = nullptr;
@ -138,20 +240,48 @@ class MlasSQNBitGemmTest : public MlasTestBase {
QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
if (Symmetric) {
if (!Symmetric) {
QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes);
}
MlasQuantizeBlockwise<float, 4>(QuantBData, QuantBScale, QuantBZeroPoint,
B, BlkLen,
/* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
static_cast<int>(N),
GetMlasThreadPool());
MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
B, BlkLen,
/* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
static_cast<int>(N),
GetMlasThreadPool());
}
CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool);
CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
void* Workspace = nullptr;
if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType);
WorkspaceSize > 0) {
Workspace = BufferWorkspace.GetBuffer(WorkspaceSize);
}
void* PackedQuantBData = nullptr;
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
PackedQuantBDataSize > 0) {
PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize);
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool());
}
if (ComputeType == CompFp32) {
CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
} else if (ComputeType == CompInt8) {
CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
} else {
FAIL() << "Test is not implemented for compute type "
<< ComputeType << " (" << ComputeTypeName(ComputeType) << ")";
}
CallGemm(M, N, K,
A, /* lda */ K,
QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint,
Bias,
C, /* ldc */ N,
Workspace,
ComputeType,
Threadpool);
size_t f = 0;
for (size_t m = 0; m < M; m++) {
@ -179,74 +309,90 @@ template <size_t BlkBitWidth, size_t BlkLen>
class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>> {
public:
explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
bool WithThreadpool, bool Symmetric, bool WithBias)
: M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) {
: M_(M),
N_(N),
K_(K),
ComputeType_(ComputeType),
WithThreadpool_(WithThreadpool),
Symmetric_(Symmetric),
WithBias_(WithBias) {
}
void TestBody() override {
MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>::mlas_tester->Test(
M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_);
M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_);
}
static size_t RegisterSingleTest(size_t M, size_t N, size_t K,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
bool WithThreadpool, bool Symmetric, bool WithBias) {
std::stringstream ss;
ss << (WithThreadpool ? "SingleThread" : "Threaded")
<< "/isSymmetric" << Symmetric
<< "/M" << M << "xN" << N << "xK" << K
<< "/hasBias" << WithBias;
auto test_name = ss.str();
size_t tests_registered = 0;
testing::RegisterTest(
MlasSQNBitGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
test_name.c_str(),
nullptr,
test_name.c_str(),
__FILE__,
__LINE__,
// Important to use the fixture type as the return type here.
[=]() -> MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>* {
return new SQNBitGemmShortExecuteTest(
M, N, K, WithThreadpool, Symmetric, WithBias);
});
if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType)) {
std::stringstream ss;
ss << (WithThreadpool ? "SingleThread" : "Threaded")
<< "/isSymmetric" << Symmetric
<< "/M" << M << "xN" << N << "xK" << K
<< "/hasBias" << WithBias
<< "/computeType" << ComputeTypeName(ComputeType);
auto test_name = ss.str();
return 1;
testing::RegisterTest(
MlasSQNBitGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
test_name.c_str(),
nullptr,
test_name.c_str(),
__FILE__,
__LINE__,
// Important to use the fixture type as the return type here.
[=]() -> MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>* {
return new SQNBitGemmShortExecuteTest(
M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias);
});
tests_registered += 1;
}
return tests_registered;
}
static size_t RegisterShortExecuteTests() {
size_t test_registered = 0;
size_t tests_registered = 0;
if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) {
for (MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType : {CompFp32, CompInt8}) {
for (bool WithThreadpool : {false, true}) {
for (bool Symmetric : {false, true}) {
for (size_t b = 1; b < 16; b++) {
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
}
for (size_t b = 16; b <= 256; b <<= 1) {
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
}
for (size_t b = 256; b < 320; b += 32) {
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
}
for (size_t b = 1; b < 96; b++) {
test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false);
test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true);
test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(1, b, 32, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(1, 32, b, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(1, b, b, ComputeType, WithThreadpool, Symmetric, false);
}
test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(43, 500, 401, ComputeType, WithThreadpool, Symmetric, true);
// test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false);
// tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false);
}
}
}
return test_registered;
return tests_registered;
}
private:
size_t M_, N_, K_;
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType_;
bool WithThreadpool_, Symmetric_, WithBias_;
};