mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
[MLAS AArch64] SQNBitGemm CompInt8 kernel (#18953)
Implement ARM NEON SQNBitGemm kernel that first block quantizes A to int8 and then does int8 multiplication.
This commit is contained in:
parent
a756017e9f
commit
150c4cb8fe
14 changed files with 1667 additions and 559 deletions
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
|
||||
set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
|
||||
set(MLAS_SRC_DIR ${MLAS_ROOT}/lib)
|
||||
set(MLAS_INC_DIR ${MLAS_ROOT}/inc)
|
||||
|
||||
#
|
||||
# All hardware agnostic source files here
|
||||
|
|
@ -9,6 +11,7 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
|
|||
# multi-target build
|
||||
#
|
||||
onnxruntime_add_static_library(onnxruntime_mlas
|
||||
${MLAS_SRC_DIR}/mlasi.h
|
||||
${MLAS_SRC_DIR}/platform.cpp
|
||||
${MLAS_SRC_DIR}/threading.cpp
|
||||
${MLAS_SRC_DIR}/sgemm.cpp
|
||||
|
|
@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/qpostprocessor.cpp
|
||||
${MLAS_SRC_DIR}/qlgavgpool.cpp
|
||||
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm.h
|
||||
${MLAS_SRC_DIR}/sqnbitgemm.cpp
|
||||
)
|
||||
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
${MLAS_INC_DIR}/mlas_float16.h
|
||||
${MLAS_INC_DIR}/mlas_gemm_postprocessor.h
|
||||
${MLAS_INC_DIR}/mlas_q4.h
|
||||
${MLAS_INC_DIR}/mlas_qnbit.h
|
||||
${MLAS_INC_DIR}/mlas.h
|
||||
)
|
||||
|
||||
if (NOT onnxruntime_ORT_MINIMAL_BUILD)
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
${MLAS_SRC_DIR}/q4_dq.cpp
|
||||
|
|
@ -46,7 +58,7 @@ endif()
|
|||
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
|
||||
|
||||
function(add_jblas)
|
||||
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
|
||||
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
|
||||
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
${MLAS_SRC_DIR}/jblas_gemm.cpp
|
||||
|
|
@ -143,10 +155,6 @@ function(setup_mlas_source_for_windows)
|
|||
target_sources(onnxruntime_mlas PRIVATE
|
||||
${MLAS_SRC_DIR}/arm/sgemmc.cpp
|
||||
)
|
||||
# This should be removed after Visual Studio is upgraded to 17.7
|
||||
if (MSVC)
|
||||
add_compile_options("-d2SSAOptimizer-")
|
||||
endif()
|
||||
elseif(onnxruntime_target_platform STREQUAL "x64")
|
||||
|
||||
file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS
|
||||
|
|
@ -300,8 +308,8 @@ else()
|
|||
if(APPLE)
|
||||
get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES)
|
||||
endif()
|
||||
list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH)
|
||||
if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1)
|
||||
list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH)
|
||||
if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1)
|
||||
set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE)
|
||||
endif()
|
||||
#If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below
|
||||
|
|
@ -348,6 +356,8 @@ else()
|
|||
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
|
||||
)
|
||||
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
|
||||
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
|
||||
if (NOT APPLE)
|
||||
set(mlas_platform_srcs
|
||||
${mlas_platform_srcs}
|
||||
|
|
@ -617,10 +627,12 @@ if(USE_JBLAS)
|
|||
endif()
|
||||
|
||||
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
|
||||
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
|
||||
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
|
||||
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
|
||||
|
||||
set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime")
|
||||
endforeach()
|
||||
set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")
|
||||
|
||||
if (WIN32)
|
||||
target_compile_options(onnxruntime_mlas PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd6385>" "$<$<COMPILE_LANGUAGE:CXX>:/wd4127>")
|
||||
if (onnxruntime_ENABLE_STATIC_ANALYSIS)
|
||||
|
|
@ -636,6 +648,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB)
|
|||
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
endif()
|
||||
|
||||
# set up source group for MLAS source files
|
||||
block()
|
||||
set(source_group_srcs)
|
||||
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
|
||||
get_target_property(mlas_target_srcs ${mlas_target} SOURCES)
|
||||
foreach(mlas_target_src ${mlas_target_srcs})
|
||||
cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root)
|
||||
if(in_mlas_root)
|
||||
list(APPEND source_group_srcs ${mlas_target_src})
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs})
|
||||
endblock()
|
||||
|
||||
|
||||
if (NOT onnxruntime_ORT_MINIMAL_BUILD)
|
||||
|
||||
|
|
@ -647,7 +674,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
|
|||
onnxruntime_add_executable(onnxruntime_mlas_q4dq
|
||||
${MLAS_SRC_DIR}/q4_dq_cli.cpp
|
||||
)
|
||||
target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
|
||||
target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
|
||||
set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest")
|
||||
|
||||
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
|
|||
if (!all_constant_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(MLAS_JBLAS)
|
||||
|
||||
auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
|
||||
MLAS_THREADPOOL* pool = NULL;
|
||||
if (input_idx == 1) {
|
||||
|
|
@ -101,12 +104,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
|
|||
is_packed = true;
|
||||
}
|
||||
|
||||
#else // defined(MLAS_JBLAS)
|
||||
|
||||
if (input_idx == 1) {
|
||||
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
|
||||
if (packed_b_size_ == 0) return Status::OK();
|
||||
auto qptr = tensor.DataRaw();
|
||||
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
|
||||
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
|
||||
if (prepacked_weights) {
|
||||
prepacked_weights->buffers_.push_back(std::move(packed_b_));
|
||||
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
|
||||
}
|
||||
is_packed = true;
|
||||
}
|
||||
|
||||
#endif // defined(MLAS_JBLAS)
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
|
||||
/*out*/ bool& used_shared_buffers) {
|
||||
used_shared_buffers = false;
|
||||
|
||||
#if defined(MLAS_JBLAS)
|
||||
|
||||
// Pack three tensors into one buffer
|
||||
if (input_idx == 1) {
|
||||
used_shared_buffers = true;
|
||||
|
|
@ -120,6 +143,15 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
|
|||
used_shared_buffers = true;
|
||||
packed_b_ = std::move(prepacked_buffers[0]);
|
||||
}
|
||||
|
||||
#else // defined(MLAS_JBLAS)
|
||||
|
||||
if (input_idx == 1) {
|
||||
used_shared_buffers = true;
|
||||
packed_b_ = std::move(prepacked_buffers[0]);
|
||||
}
|
||||
|
||||
#endif // defined(MLAS_JBLAS)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -129,6 +161,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
|
|||
const Tensor* a = ctx->Input<Tensor>(0);
|
||||
const auto* a_data = a->Data<float>();
|
||||
|
||||
#if defined(MLAS_JBLAS)
|
||||
|
||||
if (packed_b_.get()) {
|
||||
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
|
||||
|
||||
|
|
@ -158,7 +192,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
|
|||
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
|
||||
gemm_params[i].ldc = N;
|
||||
}
|
||||
auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
|
||||
auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data());
|
||||
// workspace for activation processing (dynamic quantization and others)
|
||||
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
|
||||
MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
|
||||
|
|
@ -166,10 +200,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const Tensor* b = ctx->Input<Tensor>(1);
|
||||
#endif // defined(MLAS_JBLAS)
|
||||
|
||||
const Tensor* scales = ctx->Input<Tensor>(2);
|
||||
const Tensor* zero_points = ctx->Input<Tensor>(3);
|
||||
const uint8_t* b_data = b->Data<uint8_t>();
|
||||
const auto* scales_data = scales->Data<float>();
|
||||
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
|
||||
|
||||
|
|
@ -181,8 +215,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
|
|||
Tensor* y = ctx->Output(0, helper.OutputShape());
|
||||
|
||||
// Bail out early if the output is going to be empty
|
||||
if (y->Shape().Size() == 0)
|
||||
if (y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto* y_data = y->MutableData<float>();
|
||||
|
||||
|
|
@ -192,36 +227,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
|
|||
const size_t K = static_cast<size_t>(helper.K());
|
||||
const size_t lda = helper.Lda(false);
|
||||
|
||||
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
|
||||
// number of bytes or elements between adjacent matrices
|
||||
size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
|
||||
MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
|
||||
static_cast<int>(K), static_cast<int>(N),
|
||||
b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
|
||||
&b_zero_point_matrix_stride_in_bytes);
|
||||
const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
|
||||
[](size_t offset) { return offset == 0; });
|
||||
|
||||
const size_t b_matrix_size = K * N;
|
||||
if (has_single_b_matrix && packed_b_) {
|
||||
for (int64_t accuracy_level = accuracy_level_;
|
||||
accuracy_level >= static_cast<int64_t>(CompMostAccurate);
|
||||
--accuracy_level) {
|
||||
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
|
||||
if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
|
||||
IAllocatorUniquePtr<std::byte> workspace{};
|
||||
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
|
||||
nbits_, block_size_, compute_type);
|
||||
workspace_size > 0) {
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
|
||||
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
|
||||
}
|
||||
|
||||
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
|
||||
for (size_t i = 0; i < batch_count; ++i) {
|
||||
const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
|
||||
InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
|
||||
for (size_t i = 0; i < batch_count; ++i) {
|
||||
data[i].A = a_data + helper.LeftOffsets()[i];
|
||||
data[i].lda = lda;
|
||||
data[i].QuantBData = packed_b_.get();
|
||||
data[i].QuantBScale = scales_data;
|
||||
data[i].QuantBZeroPoint = zero_points_data;
|
||||
data[i].C = y_data + helper.OutputOffsets()[i];
|
||||
data[i].ldc = N;
|
||||
}
|
||||
|
||||
data[i].A = a_data + helper.LeftOffsets()[i];
|
||||
data[i].lda = lda;
|
||||
data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
|
||||
data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
|
||||
data[i].QuantBZeroPoint = zero_points_data != nullptr
|
||||
? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
|
||||
: nullptr;
|
||||
data[i].C = y_data + helper.OutputOffsets()[i];
|
||||
data[i].ldc = N;
|
||||
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
|
||||
thread_pool);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const Tensor* b = ctx->Input<Tensor>(1);
|
||||
const uint8_t* b_data = b->Data<uint8_t>();
|
||||
|
||||
const size_t ldb = helper.Ldb(true);
|
||||
|
||||
AllocatorPtr allocator;
|
||||
|
|
|
|||
|
|
@ -23,19 +23,36 @@ Abstract:
|
|||
#include "mlas.h"
|
||||
#include "mlas_gemm_postprocessor.h"
|
||||
|
||||
/**
|
||||
* @brief Define compute types of block quantization, in order of decreasing accuracy.
|
||||
*/
|
||||
typedef enum {
|
||||
CompUndef = 0, /*!< undef */
|
||||
// The enumerators below have no explicit initializer, so they take the consecutive
// values 1, 2, 3 and 4 following CompUndef.
CompFp32, /*!< input fp32, accumulator fp32 */
|
||||
CompFp16, /*!< input fp16, accumulator fp16 */
|
||||
CompBf16, /*!< input bf16, accumulator fp32 */
|
||||
CompInt8, /*!< input int8, accumulator int32 */
|
||||
|
||||
// special values that should be the first and last actual values
|
||||
|
||||
// Alias of the first enumerator: has the same value as CompUndef (0), so a range that is
// bounded by CompMostAccurate also includes CompUndef.
CompMostAccurate = CompUndef,
|
||||
// Alias of the last enumerator: has the same value as CompInt8.
CompLeastAccurate = CompInt8,
|
||||
} MLAS_SQNBIT_COMPUTE_TYPE;
|
||||
|
||||
using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these
|
||||
|
||||
/**
|
||||
* @brief Data parameters for float/n-bit quantized int GEMM routine.
|
||||
*/
|
||||
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
|
||||
const float* A = nullptr; ///< address of A (float32 matrix)
|
||||
size_t lda = 0; ///< leading dimension of A
|
||||
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
|
||||
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
|
||||
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
|
||||
bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation
|
||||
const float* Bias = nullptr; ///< optional address of Bias, vector size N
|
||||
float* C = nullptr; ///< address of result matrix
|
||||
size_t ldc = 0; ///< leading dimension of C
|
||||
const float* A = nullptr; ///< address of A (float32 matrix)
|
||||
size_t lda = 0; ///< leading dimension of A
|
||||
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
|
||||
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
|
||||
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
|
||||
const float* Bias = nullptr; ///< optional address of Bias, vector size N
|
||||
float* C = nullptr; ///< address of result matrix
|
||||
size_t ldc = 0; ///< leading dimension of C
|
||||
|
||||
///< optional post processing to apply to result matrix
|
||||
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
|
||||
|
|
@ -46,13 +63,26 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
|
|||
* A must be a float32 matrix
|
||||
* B must be a quantized and packed n-bit int matrix
|
||||
*
|
||||
* Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
|
||||
*
|
||||
* Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
|
||||
* MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
|
||||
* MlasSQNBitGemmPackQuantBData().
|
||||
*
|
||||
* Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
|
||||
* point to an intermediate workspace buffer.
|
||||
*
|
||||
* @param[in] M row size of matrix A and C
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BatchN number of batches
|
||||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
|
||||
* @param[inout] DataParams An array (size BatchN) of parameter blocks
|
||||
* @param[in] Workspace Address of intermediate workspace buffer.
|
||||
If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
|
||||
buffer with at least that many bytes. Otherwise, it may be nullptr.
|
||||
* @param[in] ThreadPool optional thread pool to use
|
||||
*/
|
||||
void MLASCALL
|
||||
|
|
@ -63,31 +93,96 @@ MlasSQNBitGemmBatch(
|
|||
size_t BatchN,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
void* Workspace,
|
||||
MLAS_THREADPOOL* ThreadPool = nullptr
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
|
||||
*
|
||||
* @param[in] M row size of matrix A and C
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
|
||||
*
* @return true if an implementation is available for this combination of parameters,
*         false otherwise
*/
|
||||
bool MLASCALL
|
||||
MlasIsSQNBitGemmAvailable(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM
|
||||
* implementation. If zero, no intermediate workspace is required.
|
||||
*
|
||||
* @param[in] M row size of matrix A and C
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BatchN number of batches
|
||||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
|
||||
*
* @return required workspace size in bytes for the whole batch, or 0 if no workspace is required
*/
|
||||
size_t MLASCALL
|
||||
MlasSQNBitGemmBatchWorkspaceSize(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BatchN,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Gets the size in bytes of the packed quantized B data.
|
||||
* If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
|
||||
* this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
|
||||
* If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
|
||||
* MlasSQNBitGemmBatch().
|
||||
*
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
*
* @return size in bytes of the packed quantized B data buffer, or 0 if the quantized B data
*         is to be used without packing
*/
|
||||
size_t MLASCALL
|
||||
MlasSQNBitGemmPackQuantBDataSize(
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Define compute types of block quantization
|
||||
* @brief Packs the quantized B data in a format that the kernel expects.
|
||||
*
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] QuantBData quantized B data
|
||||
* @param[out] PackedQuantBData packed quantized B data
|
||||
* @param[in] ThreadPool optional thread pool to use
|
||||
*/
|
||||
typedef enum {
|
||||
CompUndef = 0, /*!< undef */
|
||||
CompFp32 = 1, /*!< input fp32, accumulator fp32 */
|
||||
CompFp16 = 2, /*!< input fp16, accumulator fp16 */
|
||||
CompBf16 = 3, /*!< input bf16, accumulator fp32 */
|
||||
CompInt8 = 4 /*!< input int8, accumulator int32 */
|
||||
} MLAS_SQNBIT_COMPUTE_TYPE;
|
||||
void MLASCALL
|
||||
MlasSQNBitGemmPackQuantBData(
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
const void* QuantBData,
|
||||
void* PackedQuantBData,
|
||||
MLAS_THREADPOOL* ThreadPool = nullptr
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Data parameters for NBits GEMM routine
|
||||
|
|
@ -139,7 +234,7 @@ MlasNBitsGemmPackBSize(
|
|||
* @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor
|
||||
* one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where
|
||||
* they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up
|
||||
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
|
||||
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
|
||||
* (is_asym is false) and Zp(is_asym is true).
|
||||
* @param thread_pool
|
||||
*/
|
||||
|
|
@ -186,7 +281,7 @@ MlasNBitsGemmUnPackB(
|
|||
* @return Workspace size in bytes
|
||||
*/
|
||||
size_t MLASCALL
|
||||
MlasSQNBitsGemmBatchWorkspaceSize(
|
||||
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
|
||||
const size_t M,
|
||||
const size_t N,
|
||||
const size_t K,
|
||||
|
|
|
|||
|
|
@ -482,7 +482,6 @@ Return Value:
|
|||
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
|
||||
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
|
||||
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
|
||||
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
|
||||
|
||||
//
|
||||
// Check if the processor supports ASIMD dot product instructions.
|
||||
|
|
@ -512,6 +511,9 @@ Return Value:
|
|||
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
|
||||
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
|
||||
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
|
||||
|
||||
// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
|
||||
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
|
||||
}
|
||||
|
||||
#if defined(__linux__)
|
||||
|
|
|
|||
|
|
@ -11,10 +11,14 @@ Module Name:
|
|||
Abstract:
|
||||
|
||||
This module implements the float/quantized n-bit integer matrix
|
||||
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch.
|
||||
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch,
|
||||
as well as some SQNBitGemm-related query functions.
|
||||
--*/
|
||||
|
||||
#include "sqnbitgemm.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#ifdef MLAS_JBLAS
|
||||
#include "jblas_gemm.h"
|
||||
#endif
|
||||
|
|
@ -22,29 +26,564 @@ Abstract:
|
|||
namespace
|
||||
{
|
||||
|
||||
// Get quantization variant based on `BlkBitWidth` and `BlkLen`.
|
||||
// Return -1 if the input values are unsupported.
|
||||
int32_t
|
||||
GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
|
||||
// Identifies which SQNBitGemm implementation applies to a given combination of quantized bit
// width and compute type. The valid values start at 0 and are contiguous, with
// SQNBitGemmVariantCount as the trailing count.
enum SQNBitGemmVariant {
|
||||
// Sentinel for parameter combinations that have no supported implementation.
SQNBitGemmVariantInvalid = -1,
|
||||
|
||||
// Valid variants
|
||||
|
||||
// 4-bit quantized B, computation done on fp32 values.
SQNBitGemmVariant_BitWidth4_CompFp32 = 0,
|
||||
// 4-bit quantized B, computation done on int8 values.
SQNBitGemmVariant_BitWidth4_CompInt8,
|
||||
|
||||
// End of valid variants
|
||||
|
||||
// Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values.
|
||||
// Its value is used as an array size.
|
||||
SQNBitGemmVariantCount,
|
||||
};
|
||||
|
||||
SQNBitGemmVariant
|
||||
GetSQNBitGemmVariant(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
int32_t type = -1;
|
||||
if (BlkBitWidth == 4 && BlkLen == 16) {
|
||||
type = QuantVariant_BitWidth4_BlockSize16;
|
||||
} else if (BlkBitWidth == 4 && BlkLen == 32) {
|
||||
type = QuantVariant_BitWidth4_BlockSize32;
|
||||
} else if (BlkBitWidth == 4 && BlkLen == 64) {
|
||||
type = QuantVariant_BitWidth4_BlockSize64;
|
||||
} else if (BlkBitWidth == 4 && BlkLen == 128) {
|
||||
type = QuantVariant_BitWidth4_BlockSize128;
|
||||
} else if (BlkBitWidth == 4 && BlkLen == 256) {
|
||||
type = QuantVariant_BitWidth4_BlockSize256;
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
MLAS_UNREFERENCED_PARAMETER(K);
|
||||
|
||||
if (BlkBitWidth == 4 &&
|
||||
(BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) {
|
||||
if (ComputeType == CompFp32 ||
|
||||
ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32
|
||||
return SQNBitGemmVariant_BitWidth4_CompFp32;
|
||||
} else if (ComputeType == CompInt8 && M == 1) {
|
||||
return SQNBitGemmVariant_BitWidth4_CompInt8;
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
return SQNBitGemmVariantInvalid;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Reports whether SQNBitGemm can be used for the given problem on the current platform.
// Returns false when the platform provides no SQNBitGemm dispatch table, when the parameters do
// not map to a valid SQNBitGemmVariant, or when any kernel that the selected variant depends on
// is missing from the dispatch table.
bool MLASCALL
|
||||
MlasIsSQNBitGemmAvailable(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
|
||||
// No dispatch table means this platform has no SQNBitGemm kernels at all.
if (Dispatch == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
|
||||
// Each variant is available only if every kernel pointer it checks below is populated.
switch (Variant) {
|
||||
case SQNBitGemmVariant_BitWidth4_CompFp32: {
|
||||
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
|
||||
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
|
||||
}
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
|
||||
Dispatch->QuantizeARow_CompInt8 != nullptr;
|
||||
}
|
||||
// Covers SQNBitGemmVariantInvalid and any other value.
default: {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// Returns the alignment used for the workspace of the given variant.
// The CompInt8 variant uses Q8BlkAlignment(); every other variant uses 1.
size_t
|
||||
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
|
||||
{
|
||||
switch (Variant) {
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
return Q8BlkAlignment();
|
||||
}
|
||||
default: {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the workspace size needed by a single GEMM of the given variant.
// Only the CompInt8 variant needs a workspace: M * ceil(K / BlkLen) * Q8BlkSize(BlkLen).
// All other variants return 0. N does not affect the result.
size_t
|
||||
SQNBitGemmPerGemmWorkspaceSize(
|
||||
SQNBitGemmVariant Variant,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
|
||||
switch (Variant) {
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
// workspace buffer is used for block quantization of A to int8
|
||||
// One quantized block per BlkLen values along K, for each of the M rows of A.
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
|
||||
return PerGemmWorkspaceSize;
|
||||
}
|
||||
default: {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the per-GEMM workspace size rounded up to a multiple of the variant's workspace
// alignment. The result is 0 when the per-GEMM workspace size is 0.
size_t
|
||||
SQNBitGemmPerGemmWorkspaceStride(
|
||||
SQNBitGemmVariant Variant,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen
|
||||
)
|
||||
{
|
||||
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
|
||||
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
// Round Size up to the next multiple of Alignment.
return MlasDivRoundup(Size, Alignment) * Alignment;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Returns the workspace size for a batch of BatchN GEMMs, or 0 if the selected variant needs
// no workspace. The size is BatchN aligned per-GEMM strides plus (Alignment - 1) extra bytes.
size_t MLASCALL
|
||||
MlasSQNBitGemmBatchWorkspaceSize(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BatchN,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
||||
// A zero stride means the variant (including an invalid one) uses no workspace.
if (PerGemmWorkspaceStride == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
|
||||
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
|
||||
|
||||
// NOTE(review): the extra (Alignment - 1) bytes presumably leave room to align the start of
// a caller-provided buffer - confirm against the code that consumes the workspace.
return WorkspaceSize + Alignment - 1;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// Rearranges 4-bit quantized B data into the packed layout, one block per parallel iteration.
//
// QuantBDataBegin         start of the source quantized B data
// PackedQuantBDataBegin   start of the destination buffer; each block is written at the same
//                         byte offset that it is read from in the source
// ThreadPool              thread pool passed through to MlasTrySimpleParallel
//
// The source is addressed as N columns of ceil(K / BlkLen) blocks each. BlkLen must be a
// multiple of 16 (checked by assert only).
void
|
||||
SQ4BitGemmPackQuantBData(
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantBDataBegin,
|
||||
std::byte* PackedQuantBDataBegin,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
// The packing loop below consumes 16 values at a time.
assert(BlkLen % 16 == 0);
|
||||
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t Iterations = N * BlockCountK; // one iteration per block
|
||||
|
||||
MlasTrySimpleParallel(
|
||||
ThreadPool, Iterations,
|
||||
[&](ptrdiff_t tid) {
|
||||
// Decompose the flat iteration index into (column n, block index k_blk along K).
const size_t n = tid / BlockCountK;
|
||||
const size_t k_blk = tid % BlockCountK;
|
||||
|
||||
// Source and destination blocks share the same byte offset.
const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
|
||||
const std::byte* QuantBData = QuantBDataBegin + data_offset;
|
||||
std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
|
||||
|
||||
//
|
||||
// Pack 16 4-bit values (8 bytes) at a time like this:
|
||||
//
|
||||
// src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF |
|
||||
// =>
|
||||
// dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF |
|
||||
//
|
||||
for (size_t kk = 0; kk < BlkLen; kk += 16) {
|
||||
// Each step combines source byte i (first half) with source byte i + 4 (second half)
// and produces destination bytes 2i and 2i + 1.
for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
|
||||
const std::byte src0 = QuantBData[byte_pair_idx];
|
||||
const std::byte src1 = QuantBData[byte_pair_idx + 4];
|
||||
|
||||
std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
|
||||
std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
|
||||
|
||||
// dst0: low nibble of src0 in the low nibble, low nibble of src1 in the high nibble.
dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
|
||||
// dst1: high nibble of src0 in the low nibble, high nibble of src1 in the high nibble.
dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
|
||||
}
|
||||
|
||||
// Advance both pointers past the 8 bytes (16 values) just processed.
QuantBData += 8;
|
||||
PackedQuantBData += 8;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Returns the size in bytes of the packed quantized B data buffer.
// Returns 0 when MlasIsSQNBitGemmAvailable() reports no implementation for the probe below
// (M = 2, CompUndef), or when BlkBitWidth is not 4.
// For BlkBitWidth == 4 the result is N * ceil(K / BlkLen) * MlasQNBitBlkDataSizeInBytes(4, BlkLen).
size_t MLASCALL
|
||||
MlasSQNBitGemmPackQuantBDataSize(
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen
|
||||
)
|
||||
{
|
||||
// Ensure that a general implementation is available on this platform.
|
||||
// For now, all implementations share the same packed format.
|
||||
{
|
||||
// Currently, there are implementations specific to M = 1, so pick a more general M > 1.
|
||||
constexpr size_t M = 2;
|
||||
// A CompUndef implementation should be available if any is available.
|
||||
constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
|
||||
const bool HasGeneralImplementation =
|
||||
MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
if (!HasGeneralImplementation) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (BlkBitWidth == 4) {
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
return PackedQuantBDataSize;
|
||||
}
|
||||
|
||||
// Other bit widths have no packed format here.
return 0;
|
||||
}
|
||||
|
||||
// Packs quantized B data into PackedQuantBData.
// Only BlkBitWidth == 4 is handled, by forwarding to SQ4BitGemmPackQuantBData(); for any other
// bit width this function does nothing and PackedQuantBData is left untouched.
void MLASCALL
|
||||
MlasSQNBitGemmPackQuantBData(
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
const void* QuantBData,
|
||||
void* PackedQuantBData,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
if (BlkBitWidth == 4) {
|
||||
SQ4BitGemmPackQuantBData(
|
||||
N,
|
||||
K,
|
||||
BlkLen,
|
||||
static_cast<const std::byte*>(QuantBData),
|
||||
static_cast<std::byte*>(PackedQuantBData),
|
||||
ThreadPool
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/**
 * @brief Adds a bias row vector to every row of a CountM x CountN tile of C.
 *
 * @param Bias      Bias values; the first CountN elements are read.
 * @param[in,out] C First element of the tile to update.
 * @param CountM    Number of rows in the tile.
 * @param CountN    Number of columns in the tile.
 * @param ldc       Number of elements between the starts of consecutive rows of C.
 */
MLAS_FORCEINLINE void
AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
{
    constexpr size_t VectorWidth = 4;

    // Columns [0, VectorizedCountN) are handled four at a time, the rest one by one.
    const size_t VectorizedCountN = CountN - (CountN % VectorWidth);

    for (size_t row = 0; row < CountM; ++row, C += ldc) {
        size_t col = 0;

        for (; col < VectorizedCountN; col += VectorWidth) {
            const MLAS_FLOAT32X4 updated =
                MlasAddFloat32x4(MlasLoadFloat32x4(C + col), MlasLoadFloat32x4(Bias + col));
            MlasStoreFloat32x4(C + col, updated);
        }

        for (; col < CountN; ++col) {
            C[col] += Bias[col];
        }
    }
}
|
||||
|
||||
typedef void(SQNBitGemmFn)(
|
||||
size_t BlkLen,
|
||||
size_t K,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
void* PerGemmWorkspace,
|
||||
size_t RangeStartM,
|
||||
size_t RangeCountM,
|
||||
size_t RangeStartN,
|
||||
size_t RangeCountN
|
||||
);
|
||||
|
||||
void
|
||||
SQ4BitGemm_CompFp32(
|
||||
const size_t BlkLen,
|
||||
const size_t K,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
|
||||
void* const PerGemmWorkspace,
|
||||
const size_t RangeStartM,
|
||||
const size_t RangeCountM,
|
||||
const size_t RangeStartN,
|
||||
const size_t RangeCountN
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);
|
||||
|
||||
const size_t lda = DataParams->lda;
|
||||
const size_t ldc = DataParams->ldc;
|
||||
|
||||
const size_t k_blks = MlasDivRoundup(K, BlkLen);
|
||||
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
|
||||
|
||||
const float* A = DataParams->A + RangeStartM * lda;
|
||||
|
||||
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
|
||||
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
|
||||
const std::byte* QuantBZeroPoint =
|
||||
(DataParams->QuantBZeroPoint == nullptr)
|
||||
? nullptr
|
||||
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
|
||||
|
||||
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
||||
|
||||
if (RangeCountM == 1) {
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, size_t{128});
|
||||
|
||||
const float* a_row = A;
|
||||
const std::byte* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const std::byte* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
|
||||
BlkLen,
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
||||
);
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM, RangeStartN + n,
|
||||
RangeCountM, CountN, ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t StrideN = 32;
|
||||
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
|
||||
MlasThreadedBufAlloc(bufsize);
|
||||
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
|
||||
|
||||
//
|
||||
// Step through each slice of matrix B along the N dimension.
|
||||
//
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, StrideN);
|
||||
|
||||
//
|
||||
// Step through each slice of matrix A along the M dimension.
|
||||
//
|
||||
const float* a_row = A;
|
||||
const std::byte* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const std::byte* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
|
||||
BlkLen,
|
||||
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
|
||||
);
|
||||
|
||||
size_t RowsRemaining = RangeCountM;
|
||||
while (RowsRemaining > 0) {
|
||||
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
|
||||
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
|
||||
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
|
||||
);
|
||||
#else
|
||||
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
|
||||
#endif
|
||||
|
||||
if (bias) {
|
||||
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
|
||||
}
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
|
||||
RowsHandled, CountN, ldc
|
||||
);
|
||||
}
|
||||
|
||||
c_blk += ldc * RowsHandled;
|
||||
a_row += lda * RowsHandled;
|
||||
RowsRemaining -= RowsHandled;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
SQ4BitGemm_CompInt8(
|
||||
const size_t BlkLen,
|
||||
const size_t K,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
|
||||
void* const PerGemmWorkspace,
|
||||
const size_t RangeStartM,
|
||||
const size_t RangeCountM,
|
||||
const size_t RangeStartN,
|
||||
const size_t RangeCountN
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
const size_t k_blks = MlasDivRoundup(K, BlkLen);
|
||||
|
||||
const size_t lda = k_blks * Q8BlkSize(BlkLen);
|
||||
const size_t ldc = DataParams->ldc;
|
||||
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
|
||||
|
||||
const std::byte* QuantA = static_cast<const std::byte*>(PerGemmWorkspace) + RangeStartM * lda;
|
||||
|
||||
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->QuantBData) + RangeStartN * ldb;
|
||||
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
|
||||
const std::byte* QuantBZeroPoint =
|
||||
(DataParams->QuantBZeroPoint == nullptr)
|
||||
? nullptr
|
||||
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
|
||||
|
||||
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
||||
|
||||
if (RangeCountM == 1) {
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, size_t{128});
|
||||
|
||||
const std::byte* a_row = QuantA;
|
||||
const std::byte* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const std::byte* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
|
||||
BlkLen,
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
||||
);
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM, RangeStartN + n,
|
||||
RangeCountM, CountN, ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
assert(false && "not implemented for M > 1");
|
||||
}
|
||||
|
||||
typedef void(InitializeWorkspaceFn)(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BatchN,
|
||||
size_t BlkLen,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
void* Workspace,
|
||||
size_t PerGemmWorkspaceStride,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
||||
void
|
||||
InitializeWorkspace_CompInt8(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BatchN,
|
||||
size_t BlkLen,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
void* Workspace,
|
||||
size_t PerGemmWorkspaceStride,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
|
||||
const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8;
|
||||
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen);
|
||||
|
||||
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
|
||||
const auto& data = DataParams[gemm_idx];
|
||||
|
||||
const float* ARowPtr = data.A;
|
||||
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
|
||||
|
||||
for (size_t m = 0; m < M; ++m) {
|
||||
QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
|
||||
|
||||
ARowPtr += data.lda;
|
||||
QuantARowPtr += QuantAStride;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
struct Operations {
|
||||
InitializeWorkspaceFn* InitializeWorkspace = nullptr;
|
||||
SQNBitGemmFn* SQNBitGemm = nullptr;
|
||||
};
|
||||
|
||||
constexpr auto OperationMap = []() {
|
||||
std::array<Operations, SQNBitGemmVariantCount> ops;
|
||||
|
||||
ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
|
||||
|
||||
ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
|
||||
ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
|
||||
|
||||
return ops;
|
||||
}();
|
||||
|
||||
} // namespace
|
||||
|
||||
void MLASCALL
|
||||
MlasSQNBitGemmBatch(
|
||||
const size_t M,
|
||||
|
|
@ -53,17 +592,43 @@ MlasSQNBitGemmBatch(
|
|||
const size_t BatchN,
|
||||
const size_t BlkBitWidth,
|
||||
const size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
void* Workspace,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
|
||||
MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];
|
||||
const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
assert(Variant != SQNBitGemmVariantInvalid);
|
||||
|
||||
//
|
||||
// Ensure `Workspace` has correct alignment.
|
||||
//
|
||||
if (Workspace != nullptr) {
|
||||
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
|
||||
Workspace = reinterpret_cast<void*>(
|
||||
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
|
||||
);
|
||||
}
|
||||
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
||||
|
||||
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
|
||||
InitializeWorkspaceOperation != nullptr) {
|
||||
InitializeWorkspaceOperation(
|
||||
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
|
||||
);
|
||||
}
|
||||
|
||||
const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
|
||||
|
||||
if (ThreadPool == nullptr) {
|
||||
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
|
||||
auto Data = &DataParams[gemm_i];
|
||||
Operation(K, Data, 0, M, 0, N);
|
||||
const auto* Data = &DataParams[gemm_i];
|
||||
void* PerGemmWorkspace =
|
||||
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
|
||||
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
@ -112,7 +677,10 @@ MlasSQNBitGemmBatch(
|
|||
MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
|
||||
const auto gemm_i = tid / ThreadsPerGemm;
|
||||
const auto blk_i = tid % ThreadsPerGemm;
|
||||
auto Data = &DataParams[gemm_i];
|
||||
const auto* Data = &DataParams[gemm_i];
|
||||
void* PerGemmWorkspace = reinterpret_cast<void*>(
|
||||
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride
|
||||
);
|
||||
|
||||
const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
|
||||
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
|
||||
|
|
@ -123,29 +691,10 @@ MlasSQNBitGemmBatch(
|
|||
const size_t RangeStartN = ThreadIdN * StrideN;
|
||||
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
|
||||
|
||||
Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
|
||||
ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
|
||||
});
|
||||
}
|
||||
|
||||
bool MLASCALL
|
||||
MlasIsSQNBitGemmAvailable(
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen
|
||||
)
|
||||
{
|
||||
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
|
||||
if (QuantVariant == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GetMlasPlatform().SQNBitGemmDispatch == nullptr ||
|
||||
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t MLASCALL
|
||||
MlasNBitsGemmPackBSize(
|
||||
size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
|
||||
|
|
@ -224,7 +773,7 @@ MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, s
|
|||
}
|
||||
|
||||
size_t MLASCALL
|
||||
MlasSQNBitsGemmBatchWorkspaceSize(
|
||||
MlasSQNBitsGemmBatchPackedBWorkspaceSize(
|
||||
const size_t M,
|
||||
const size_t N,
|
||||
const size_t K,
|
||||
|
|
|
|||
|
|
@ -10,98 +10,23 @@ Module Name:
|
|||
|
||||
Abstract:
|
||||
|
||||
This module includes:
|
||||
This module includes kernel function prototypes and helper functions for
|
||||
implementing SQNBitGemm.
|
||||
|
||||
- Declaration of the set of template functions used to implement a kernel
|
||||
for a matrix/matrix multiplication, A*B, where A is a float matrix and B is
|
||||
a n-bit quantized integer matrix (QNBitGemm).
|
||||
|
||||
- A shared kernel driver function template, MlasSQNBitGemmOperation.
|
||||
|
||||
- Kernel dispatch structure.
|
||||
|
||||
The B matrix is block quantized, which means that its values are grouped
|
||||
into blocks which each have one scale and optional zero point. Each
|
||||
quantized value in B is n-bits wide.
|
||||
SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float
|
||||
matrix and B is a n-bit quantized integer matrix. B is block quantized,
|
||||
meaning values of B are divided into blocks and each block has its own
|
||||
scale and optional zero point.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlas_qnbit.h"
|
||||
#include "mlasi.h"
|
||||
|
||||
//
|
||||
// Kernel implementation template declarations
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief Multiply float matrix A with quantized n-bit integer matrix B.
|
||||
* B is block quantized and column major.
|
||||
* This kernel handles the special case where M, the number of rows of A and C, is 1.
|
||||
*
|
||||
* @tparam BlkBitWidth Bit width of each value in a block.
|
||||
* @tparam BlkLen Number of values in a block.
|
||||
* @tparam KernelType Hardware-specific kernel type.
|
||||
*
|
||||
* @param A Supplies the A matrix.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param[out] C Supplies the output C matrix.
|
||||
* @param CountN Number of columns of B and C.
|
||||
* @param CountK Number of columns of A and rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
* @param Bias Bias vector of length N.
|
||||
*/
|
||||
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
|
||||
MLAS_FORCEINLINE void
|
||||
MlasSQNBitGemmM1Kernel(
|
||||
const float* A,
|
||||
const uint8_t* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Dequantize B into the format expected by the Sgemm kernel.
|
||||
* B is block quantized and column major.
|
||||
* This is equivalent to dequantizing B and then running
|
||||
* MlasSgemmCopyPackB.
|
||||
*
|
||||
* @tparam BlkBitWidth Bit width of each value in a block.
|
||||
* @tparam BlkLen Number of values in a block.
|
||||
* @tparam KernelType Hardware-specific kernel type.
|
||||
*
|
||||
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param CountN Number of columns of B.
|
||||
* @param CountK Number of rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
*/
|
||||
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
|
||||
MLAS_FORCEINLINE void
|
||||
MlasQNBitBlkDequantBForSgemm(
|
||||
float* FpData,
|
||||
const uint8_t* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB
|
||||
);
|
||||
|
||||
//
|
||||
// MlasQNBitGemmOperation and helpers
|
||||
//
|
||||
|
||||
constexpr MLAS_FORCEINLINE size_t
|
||||
MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
|
||||
{
|
||||
|
|
@ -119,169 +44,174 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE void
|
||||
MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
|
||||
{
|
||||
for (size_t m = 0; m < CountM; m++) {
|
||||
const float* bias = Bias;
|
||||
float* sum = C;
|
||||
for (size_t n = 0; n < CountN; n += 4) {
|
||||
if (CountN - n < 4) {
|
||||
for (size_t nn = n; nn < CountN; nn++) {
|
||||
*sum += *bias;
|
||||
sum++;
|
||||
bias++;
|
||||
}
|
||||
break;
|
||||
}
|
||||
//
|
||||
// Quantized int8 block helpers.
|
||||
//
|
||||
|
||||
MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
|
||||
acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
|
||||
MlasStoreFloat32x4(sum, acc_x);
|
||||
bias += 4;
|
||||
sum += 4;
|
||||
}
|
||||
C += ldc;
|
||||
}
|
||||
MLAS_FORCEINLINE
|
||||
const float&
|
||||
Q8BlkScale(const std::byte* BlkPtr)
|
||||
{
|
||||
return *reinterpret_cast<const float*>(BlkPtr);
|
||||
}
|
||||
|
||||
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
|
||||
MLAS_FORCEINLINE void MLASCALL
|
||||
MlasSQNBitGemmOperation(
|
||||
const size_t K,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
|
||||
const size_t RangeStartM,
|
||||
const size_t RangeCountM,
|
||||
const size_t RangeStartN,
|
||||
const size_t RangeCountN
|
||||
)
|
||||
MLAS_FORCEINLINE
|
||||
float&
|
||||
Q8BlkScale(std::byte* BlkPtr)
|
||||
{
|
||||
const size_t lda = DataParams->lda;
|
||||
const size_t ldc = DataParams->ldc;
|
||||
return *reinterpret_cast<float*>(BlkPtr);
|
||||
}
|
||||
|
||||
const size_t k_blks = MlasDivRoundup(K, BlkLen);
|
||||
const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
|
||||
MLAS_FORCEINLINE
|
||||
const int8_t*
|
||||
Q8BlkData(const std::byte* BlkPtr)
|
||||
{
|
||||
return reinterpret_cast<const int8_t*>(BlkPtr + sizeof(float));
|
||||
}
|
||||
|
||||
const float* A = DataParams->A + RangeStartM * lda;
|
||||
MLAS_FORCEINLINE
|
||||
int8_t*
|
||||
Q8BlkData(std::byte* BlkPtr)
|
||||
{
|
||||
return reinterpret_cast<int8_t*>(BlkPtr + sizeof(float));
|
||||
}
|
||||
|
||||
const uint8_t* QuantBData = static_cast<const uint8_t*>(DataParams->QuantBData) + RangeStartN * ldb;
|
||||
const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
|
||||
const uint8_t* QuantBZeroPoint =
|
||||
(DataParams->QuantBZeroPoint == nullptr)
|
||||
? nullptr
|
||||
: static_cast<const uint8_t*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
|
||||
MLAS_FORCEINLINE
|
||||
constexpr size_t
|
||||
Q8BlkSize(size_t BlkLen)
|
||||
{
|
||||
const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t);
|
||||
// Currently, the strictest alignment requirement of a block is for a float.
|
||||
// Ensure contiguous blocks are suitably aligned.
|
||||
assert(BlkSize % alignof(float) == 0);
|
||||
return BlkSize;
|
||||
}
|
||||
|
||||
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
|
||||
|
||||
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
|
||||
|
||||
if (RangeCountM == 1) {
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, size_t{128});
|
||||
|
||||
const float* a_row = A;
|
||||
const uint8_t* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const uint8_t* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, KernelType>(
|
||||
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
|
||||
);
|
||||
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM, RangeStartN + n,
|
||||
RangeCountM, CountN, ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t StrideN = 32;
|
||||
size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
|
||||
MlasThreadedBufAlloc(bufsize);
|
||||
auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
|
||||
//
|
||||
// Step through each slice of matrix B along the N dimension.
|
||||
//
|
||||
|
||||
size_t CountN;
|
||||
for (size_t n = 0; n < RangeCountN; n += CountN) {
|
||||
CountN = std::min(RangeCountN - n, StrideN);
|
||||
|
||||
//
|
||||
// Step through each slice of matrix A along the M dimension.
|
||||
//
|
||||
const float* a_row = A;
|
||||
const uint8_t* b_col = QuantBData + n * ldb;
|
||||
const float* b_col_scale = QuantBScale + n * k_blks;
|
||||
const uint8_t* b_col_zp =
|
||||
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
|
||||
float* c_blk = C + n;
|
||||
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
|
||||
|
||||
MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, KernelType>(
|
||||
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
|
||||
);
|
||||
|
||||
size_t RowsRemaining = RangeCountM;
|
||||
while (RowsRemaining > 0) {
|
||||
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
|
||||
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
|
||||
a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
|
||||
);
|
||||
#else
|
||||
auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
|
||||
#endif
|
||||
|
||||
if (bias) {
|
||||
MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
|
||||
}
|
||||
if (DataParams->PostProcessor != nullptr) {
|
||||
DataParams->PostProcessor->Process(
|
||||
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
|
||||
RowsHandled, CountN, ldc
|
||||
);
|
||||
}
|
||||
|
||||
c_blk += ldc * RowsHandled;
|
||||
a_row += lda * RowsHandled;
|
||||
RowsRemaining -= RowsHandled;
|
||||
}
|
||||
}
|
||||
MLAS_FORCEINLINE
|
||||
constexpr size_t
|
||||
Q8BlkAlignment()
|
||||
{
|
||||
return alignof(float);
|
||||
}
|
||||
|
||||
//
|
||||
// Kernel dispatch structure.
|
||||
//
|
||||
|
||||
typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)(
|
||||
size_t K,
|
||||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
|
||||
size_t RangeStartM,
|
||||
size_t RangeCountM,
|
||||
size_t RangeStartN,
|
||||
size_t RangeCountN
|
||||
);
|
||||
|
||||
enum QuantVariant {
|
||||
QuantVariant_BitWidth4_BlockSize16,
|
||||
QuantVariant_BitWidth4_BlockSize32,
|
||||
QuantVariant_BitWidth4_BlockSize64,
|
||||
QuantVariant_BitWidth4_BlockSize128,
|
||||
QuantVariant_BitWidth4_BlockSize256,
|
||||
QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values.
|
||||
// Its value is used as an array size.
|
||||
};
|
||||
|
||||
struct MLAS_SQNBIT_GEMM_DISPATCH {
|
||||
MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = {
|
||||
// Initialized to nullptrs. Overwrite in hardware-specific kernel implementation.
|
||||
};
|
||||
//
|
||||
// CompFp32 kernel function prototypes.
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief Multiply float matrix A with quantized 4-bit integer matrix B.
|
||||
* B is block quantized and column major.
|
||||
* This kernel handles the special case where M, the number of rows of A and C, is 1.
|
||||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param A Supplies the A matrix.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param[out] C Supplies the output C matrix.
|
||||
* @param CountN Number of columns of B and C.
|
||||
* @param CountK Number of columns of A and rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
* @param Bias Bias vector of length N.
|
||||
*/
|
||||
typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Dequantize B into the format expected by the Sgemm kernel.
|
||||
* B is a quantized 4-bit integer matrix that is block quantized and column major.
|
||||
* This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
|
||||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param CountN Number of columns of B.
|
||||
* @param CountK Number of rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
*/
|
||||
typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB
|
||||
);
|
||||
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
|
||||
|
||||
//
|
||||
// CompInt8 kernel function prototypes.
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
|
||||
* A and B are block quantized and B is column major.
|
||||
* This kernel handles the special case where M, the number of rows of A and C, is 1.
|
||||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param QuantA Supplies the quantized A matrix.
|
||||
* Binary data containing block quantized int8 data and scale values.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
* @param[out] C Supplies the output C matrix.
|
||||
* @param CountN Number of columns of B and C.
|
||||
* @param CountK Number of columns of A and rows of B.
|
||||
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
|
||||
* @param Bias Bias vector of length N.
|
||||
*/
|
||||
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
const float* Bias
|
||||
);
|
||||
|
||||
SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
|
||||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param A Supplies the A matrix.
|
||||
* @param CountK Number of columns of A.
|
||||
* @param[out] QuantA Supplies the output quantized A matrix.
|
||||
* Binary data containing block quantized int8 data and scale values.
|
||||
*/
|
||||
typedef void(QuantizeARow_CompInt8_Fn)(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
size_t CountK,
|
||||
std::byte* QuantA
|
||||
);
|
||||
|
||||
QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -23,12 +23,6 @@ Abstract:
|
|||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
//
|
||||
// Hardware-specific kernel type.
|
||||
//
|
||||
struct MLAS_SQNBIT_GEMM_KERNEL_NEON {
|
||||
};
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
|
|
@ -70,7 +64,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
|
|||
|
||||
template <size_t Capacity>
|
||||
MLAS_FORCEINLINE void
|
||||
LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4])
|
||||
LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
|
||||
{
|
||||
static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4.");
|
||||
|
||||
|
|
@ -101,13 +95,14 @@ LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4])
|
|||
}
|
||||
}
|
||||
|
||||
template <size_t BlkBitWidth, size_t BlkLen, size_t NCols>
|
||||
template <size_t NCols>
|
||||
MLAS_FORCEINLINE void
|
||||
ComputeDotProducts(
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* ARowPtr,
|
||||
const uint8_t* QuantBDataColPtr,
|
||||
const std::byte* QuantBDataColPtr,
|
||||
const float* QuantBScaleColPtr,
|
||||
const uint8_t* QuantBZeroPointColPtr,
|
||||
const std::byte* QuantBZeroPointColPtr,
|
||||
float* SumPtr,
|
||||
size_t CountK,
|
||||
size_t StrideQuantBData,
|
||||
|
|
@ -116,8 +111,13 @@ ComputeDotProducts(
|
|||
const float* BiasPtr
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
|
||||
|
||||
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
|
||||
assert(BlkLen % SubBlkLen == 0);
|
||||
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
// Manual conversion to float takes place in two steps:
|
||||
|
|
@ -135,7 +135,7 @@ ComputeDotProducts(
|
|||
|
||||
float32x4_t acc[NCols]{};
|
||||
|
||||
const uint8_t* QuantBData = QuantBDataColPtr;
|
||||
const std::byte* QuantBData = QuantBDataColPtr;
|
||||
const float* QuantBScale = QuantBScaleColPtr;
|
||||
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
|
||||
|
||||
|
|
@ -150,10 +150,12 @@ ComputeDotProducts(
|
|||
float offset[NCols]; // Includes zero point and float conversion offset of 16.
|
||||
if (QuantBZeroPointColPtr != nullptr) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const uint8_t zp_packed =
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
|
||||
const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F);
|
||||
offset[i] = 16.0f + zp;
|
||||
const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
|
||||
});
|
||||
} else {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
|
|
@ -162,33 +164,27 @@ ComputeDotProducts(
|
|||
});
|
||||
}
|
||||
|
||||
constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration
|
||||
|
||||
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
|
||||
// load A row vector elements
|
||||
|
||||
// load `SubBlkLen` elements from A, padded with 0's if there aren't enough
|
||||
const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen);
|
||||
float32x4_t av[4]{};
|
||||
LoadData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
|
||||
LoadFloatData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
|
||||
|
||||
// load B column vectors
|
||||
uint8x8_t bv_packed[NCols];
|
||||
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
|
||||
bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset);
|
||||
});
|
||||
|
||||
uint8x8_t bv_u8_unzipped[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask);
|
||||
bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask);
|
||||
bv_packed[i] = vld1_u8(
|
||||
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
|
||||
);
|
||||
});
|
||||
|
||||
uint8x8_t bv_u8[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
|
||||
bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
|
||||
bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
|
||||
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
|
||||
});
|
||||
|
||||
// dequantize B
|
||||
|
|
@ -262,19 +258,13 @@ ComputeDotProducts(
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//
|
||||
// MlasSQNBitGemmKernel and helpers.
|
||||
//
|
||||
|
||||
template <size_t BlkBitWidth, size_t BlkLen>
|
||||
MLAS_FORCEINLINE void
|
||||
MlasSQNBitGemmM1KernelNeon(
|
||||
SQ4BitGemmM1Kernel_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
const uint8_t* QuantBData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
|
|
@ -282,6 +272,7 @@ MlasSQNBitGemmM1KernelNeon(
|
|||
const float* Bias
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
constexpr size_t NCols = 4;
|
||||
|
||||
const float* ARowPtr = A;
|
||||
|
|
@ -295,16 +286,17 @@ MlasSQNBitGemmM1KernelNeon(
|
|||
|
||||
const float* BiasPtr = Bias;
|
||||
|
||||
const uint8_t* QuantBDataColPtr = QuantBData;
|
||||
const std::byte* QuantBDataColPtr = QuantBData;
|
||||
const float* QuantBScaleColPtr = QuantBScale;
|
||||
const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint;
|
||||
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
|
||||
|
||||
float* SumPtr = CRowPtr;
|
||||
|
||||
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
|
||||
|
||||
while (nblk >= 0) {
|
||||
ComputeDotProducts<BlkBitWidth, BlkLen, NCols>(
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols>(
|
||||
BlkLen,
|
||||
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
|
|
@ -327,7 +319,8 @@ MlasSQNBitGemmM1KernelNeon(
|
|||
// left over columns less than `NCols`?
|
||||
nblk += NCols;
|
||||
for (int64_t n = 0; n < nblk; ++n) {
|
||||
ComputeDotProducts<BlkBitWidth, BlkLen, 1>(
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32<1>(
|
||||
BlkLen,
|
||||
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
|
|
@ -346,59 +339,26 @@ MlasSQNBitGemmM1KernelNeon(
|
|||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \
|
||||
template <> \
|
||||
MLAS_FORCEINLINE void \
|
||||
MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
|
||||
const float* A, \
|
||||
const uint8_t* QuantBData, \
|
||||
const float* QuantBScale, \
|
||||
const uint8_t* QuantBZeroPoint, \
|
||||
float* C, \
|
||||
size_t CountN, \
|
||||
size_t CountK, \
|
||||
size_t BlockStrideQuantB, \
|
||||
const float* Bias \
|
||||
) \
|
||||
{ \
|
||||
return MlasSQNBitGemmM1KernelNeon<BlkBitWidth, BlkLen>( \
|
||||
A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \
|
||||
BlockStrideQuantB, Bias \
|
||||
); \
|
||||
}
|
||||
|
||||
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16)
|
||||
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32)
|
||||
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64)
|
||||
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128)
|
||||
SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256)
|
||||
|
||||
#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL
|
||||
|
||||
//
|
||||
// MlasQNBitBlkDequantBForSgemm and helpers.
|
||||
//
|
||||
|
||||
template <size_t BlkBitWidth, size_t BlkLen>
|
||||
MLAS_FORCEINLINE void
|
||||
MlasQNBitBlkDequantBForSgemmNeon(
|
||||
Q4BitBlkDequantBForSgemm_CompFp32(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const uint8_t* QuantBData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB
|
||||
)
|
||||
{
|
||||
auto impl0_reference = [&]() {
|
||||
static_assert(BlkBitWidth == 4);
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
float* Dst = FpData;
|
||||
|
||||
const uint8_t* QuantBDataCol = QuantBData;
|
||||
const std::byte* QuantBDataCol = QuantBData;
|
||||
const float* QuantBScaleCol = QuantBScale;
|
||||
const uint8_t* QuantBZeroPointCol = QuantBZeroPoint;
|
||||
const std::byte* QuantBZeroPointCol = QuantBZeroPoint;
|
||||
|
||||
for (size_t n = 0; n < CountN; n += 16) {
|
||||
const size_t nnlen = std::min(CountN - n, size_t{16});
|
||||
|
|
@ -407,20 +367,26 @@ MlasQNBitBlkDequantBForSgemmNeon(
|
|||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
const uint8_t* b_data =
|
||||
const std::byte* b_data =
|
||||
QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const float b_s = QuantBScaleCol[k_blk_idx];
|
||||
const uint8_t b_z =
|
||||
(QuantBZeroPointCol != nullptr)
|
||||
? ((k_blk_idx & 1) == 1)
|
||||
? QuantBZeroPointCol[k_blk_idx / 2] >> 4
|
||||
: QuantBZeroPointCol[k_blk_idx / 2] & 0x0F
|
||||
? std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] >> 4)
|
||||
: std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F})
|
||||
: 8;
|
||||
|
||||
for (size_t kk = 0; kk < kklen; ++kk) {
|
||||
const uint8_t b_packed = b_data[kk / 2];
|
||||
const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F;
|
||||
const float b_value = (b_byte - b_z) * b_s;
|
||||
const size_t packed_idx = kk % 16;
|
||||
|
||||
const bool is_low_half = packed_idx < 8;
|
||||
const size_t packed_byte_idx = packed_idx % 8;
|
||||
const size_t packed_range_offset = (kk / 16) * 8;
|
||||
|
||||
const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx];
|
||||
const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4);
|
||||
const float b_value = (std::to_integer<int8_t>(b_byte) - b_z) * b_s;
|
||||
|
||||
Dst[(k + kk) * 16 + nn] = b_value;
|
||||
}
|
||||
|
|
@ -448,31 +414,332 @@ MlasQNBitBlkDequantBForSgemmNeon(
|
|||
impl0_reference();
|
||||
}
|
||||
|
||||
#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \
|
||||
template <> \
|
||||
MLAS_FORCEINLINE void \
|
||||
MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
|
||||
float* FpData, \
|
||||
const uint8_t* QuantBData, \
|
||||
const float* QuantBScale, \
|
||||
const uint8_t* QuantBZeroPoint, \
|
||||
size_t CountN, \
|
||||
size_t CountK, \
|
||||
size_t BlockStrideQuantB \
|
||||
) \
|
||||
{ \
|
||||
MlasQNBitBlkDequantBForSgemmNeon<BlkBitWidth, BlkLen>( \
|
||||
FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \
|
||||
); \
|
||||
//
|
||||
// CompInt8 kernel implementation and related helpers
|
||||
//
|
||||
|
||||
template <size_t SubBlkLen>
|
||||
MLAS_FORCEINLINE void
|
||||
QuantizeBlock(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
size_t ElementCount,
|
||||
std::byte* QuantA
|
||||
)
|
||||
{
|
||||
static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0);
|
||||
|
||||
assert(BlkLen % SubBlkLen == 0);
|
||||
|
||||
constexpr size_t VectorCount = SubBlkLen / 4;
|
||||
|
||||
//
|
||||
// Scan block values first to determine scale.
|
||||
//
|
||||
|
||||
float amax = 0.0f; // max of absolute values of A block
|
||||
|
||||
size_t k;
|
||||
for (k = 0; k < ElementCount; k += SubBlkLen) {
|
||||
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
|
||||
|
||||
float32x4_t a[VectorCount]{};
|
||||
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
|
||||
|
||||
float32x4_t abs_a[VectorCount];
|
||||
UnrolledLoop<VectorCount>([&](size_t i) {
|
||||
abs_a[i] = vabsq_f32(a[i]);
|
||||
});
|
||||
|
||||
// find amax of SubBlkLen elements
|
||||
for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) {
|
||||
for (size_t i = 0; i < interval; ++i) {
|
||||
abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]);
|
||||
}
|
||||
}
|
||||
|
||||
// update existing amax
|
||||
amax = std::max(amax, vmaxvq_f32(abs_a[0]));
|
||||
}
|
||||
|
||||
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16)
|
||||
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32)
|
||||
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64)
|
||||
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128)
|
||||
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256)
|
||||
constexpr float range_max = (1 << 7) - 1;
|
||||
const float scale = amax / range_max;
|
||||
const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
|
||||
|
||||
#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM
|
||||
Q8BlkScale(QuantA) = scale;
|
||||
|
||||
//
|
||||
// Compute quantized block values.
|
||||
//
|
||||
|
||||
int8_t* QuantAData = Q8BlkData(QuantA);
|
||||
|
||||
for (k = 0; k < ElementCount; k += SubBlkLen) {
|
||||
const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen);
|
||||
|
||||
float32x4_t a[VectorCount]{};
|
||||
LoadFloatData<SubBlkLen>(A + k, SubBlkElementCount, a);
|
||||
|
||||
UnrolledLoop<VectorCount>([&](size_t i) {
|
||||
a[i] = vmulq_n_f32(a[i], scale_reciprocal);
|
||||
});
|
||||
|
||||
int32x4_t a_s32[VectorCount];
|
||||
UnrolledLoop<VectorCount>([&](size_t i) {
|
||||
a_s32[i] = vcvtaq_s32_f32(a[i]);
|
||||
});
|
||||
|
||||
UnrolledLoop<VectorCount>([&](size_t i) {
|
||||
QuantAData[k + i * 4 + 0] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 0));
|
||||
QuantAData[k + i * 4 + 1] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 1));
|
||||
QuantAData[k + i * 4 + 2] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 2));
|
||||
QuantAData[k + i * 4 + 3] = static_cast<int8_t>(vgetq_lane_s32(a_s32[i], 3));
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Zero out any remaining sub-block elements.
|
||||
//
|
||||
|
||||
for (; k < BlkLen; k += SubBlkLen) {
|
||||
const int8x16_t Zeros = vdupq_n_s8(0);
|
||||
UnrolledLoop<SubBlkLen / 16>([&](size_t i) {
|
||||
vst1q_s8(QuantAData + k + i * 16, Zeros);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void MLASCALL
|
||||
QuantizeARow_CompInt8(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
size_t CountK,
|
||||
std::byte* QuantA
|
||||
)
|
||||
{
|
||||
const float* ADataBlkPtr = A;
|
||||
std::byte* QuantABlkPtr = QuantA;
|
||||
|
||||
for (size_t k = 0; k < CountK; k += BlkLen) {
|
||||
const size_t k_blk_len = std::min(CountK - k, BlkLen);
|
||||
|
||||
QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr);
|
||||
|
||||
ADataBlkPtr += BlkLen;
|
||||
QuantABlkPtr += Q8BlkSize(BlkLen);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t NCols>
|
||||
MLAS_FORCEINLINE void
|
||||
ComputeDotProducts_BlkBitWidth4_CompInt8(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantARowPtr,
|
||||
const std::byte* QuantBDataColPtr,
|
||||
const float* QuantBScaleColPtr,
|
||||
const std::byte* QuantBZeroPointColPtr,
|
||||
float* SumPtr,
|
||||
size_t CountK,
|
||||
size_t StrideQuantBData,
|
||||
size_t StrideQuantBScale,
|
||||
size_t StrideQuantBZeroPoint,
|
||||
const float* BiasPtr
|
||||
)
|
||||
{
|
||||
static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
|
||||
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration
|
||||
assert(BlkLen % SubBlkLen == 0);
|
||||
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
const std::byte* QuantA = QuantARowPtr;
|
||||
|
||||
const std::byte* QuantBData = QuantBDataColPtr;
|
||||
const float* QuantBScale = QuantBScaleColPtr;
|
||||
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
|
||||
|
||||
float32x4_t acc[NCols]{};
|
||||
|
||||
for (size_t k = 0; k < CountK; k += BlkLen) {
|
||||
const size_t k_blk_len = std::min(CountK - k, BlkLen);
|
||||
|
||||
const float a_scale = Q8BlkScale(QuantA);
|
||||
const int8_t* a_data = Q8BlkData(QuantA);
|
||||
|
||||
float b_scale[NCols];
|
||||
UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });
|
||||
|
||||
int8_t b_zp[NCols];
|
||||
if (QuantBZeroPointColPtr != nullptr) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
|
||||
b_zp[i] = ((QuantBZeroPointIdx & 1) == 1)
|
||||
? std::to_integer<int8_t>(zp_packed >> 4)
|
||||
: std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
|
||||
});
|
||||
} else {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
b_zp[i] = 8;
|
||||
});
|
||||
}
|
||||
|
||||
for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
|
||||
// load A row vector
|
||||
int8x16_t av = vld1q_s8(a_data + k_idx_in_blk);
|
||||
|
||||
// load B column vectors
|
||||
uint8x8_t bv_packed[NCols];
|
||||
const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_packed[i] = vld1_u8(
|
||||
reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
|
||||
);
|
||||
});
|
||||
|
||||
int8x16_t bv[NCols];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask));
|
||||
const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
|
||||
bv[i] = vcombine_s8(lo, hi);
|
||||
});
|
||||
|
||||
// subtract B zero point
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
|
||||
bv[i] = vsubq_s8(bv[i], zp_v);
|
||||
});
|
||||
|
||||
// compute quantized dot product
|
||||
int32x4_t dot[NCols]{};
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
dot[i] = vdotq_s32(dot[i], av, bv[i]);
|
||||
});
|
||||
|
||||
// convert dot product result to float
|
||||
float32x4_t dot_f32[NCols];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
dot_f32[i] = vcvtq_f32_s32(dot[i]);
|
||||
});
|
||||
|
||||
// multiply dot product result by scale and update accumulator
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]);
|
||||
acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v);
|
||||
});
|
||||
}
|
||||
|
||||
// increment pointers to next block
|
||||
QuantA += Q8BlkSize(BlkLen);
|
||||
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
QuantBScale += 1;
|
||||
QuantBZeroPointIdx += 1;
|
||||
}
|
||||
|
||||
if constexpr (NCols == 4) {
|
||||
float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
|
||||
|
||||
if (BiasPtr != nullptr) {
|
||||
sum = vaddq_f32(sum, vld1q_f32(BiasPtr));
|
||||
}
|
||||
|
||||
vst1q_f32(SumPtr, sum);
|
||||
} else {
|
||||
for (size_t i = 0; i < NCols; ++i) {
|
||||
SumPtr[i] = vaddvq_f32(acc[i]);
|
||||
if (BiasPtr != nullptr) {
|
||||
SumPtr[i] += BiasPtr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompInt8(
|
||||
size_t BlkLen,
|
||||
const std::byte* QuantA,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
float* C,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB,
|
||||
const float* Bias
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
constexpr size_t NCols = 4;
|
||||
|
||||
const std::byte* QuantARowPtr = QuantA;
|
||||
float* CRowPtr = C;
|
||||
|
||||
const size_t BlockCountK = BlockStrideQuantB;
|
||||
|
||||
const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const size_t StrideQuantBScale = BlockCountK;
|
||||
const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
|
||||
|
||||
const float* BiasPtr = Bias;
|
||||
|
||||
const std::byte* QuantBDataColPtr = QuantBData;
|
||||
const float* QuantBScaleColPtr = QuantBScale;
|
||||
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
|
||||
|
||||
float* SumPtr = CRowPtr;
|
||||
|
||||
int64_t nblk = static_cast<int64_t>(CountN) - NCols;
|
||||
|
||||
while (nblk >= 0) {
|
||||
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols>(
|
||||
BlkLen,
|
||||
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
);
|
||||
|
||||
// move to next `NCols` columns
|
||||
|
||||
QuantBDataColPtr += NCols * StrideQuantBData;
|
||||
QuantBScaleColPtr += NCols * StrideQuantBScale;
|
||||
if (QuantBZeroPointColPtr != nullptr) {
|
||||
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
|
||||
}
|
||||
|
||||
BiasPtr += BiasPtr != nullptr ? NCols : 0;
|
||||
SumPtr += NCols;
|
||||
|
||||
nblk -= NCols;
|
||||
}
|
||||
|
||||
// left over columns less than `NCols`?
|
||||
nblk += NCols;
|
||||
for (int64_t n = 0; n < nblk; ++n) {
|
||||
ComputeDotProducts_BlkBitWidth4_CompInt8<1>(
|
||||
BlkLen,
|
||||
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
|
||||
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
|
||||
BiasPtr
|
||||
);
|
||||
|
||||
// move to next column
|
||||
|
||||
QuantBDataColPtr += StrideQuantBData;
|
||||
QuantBScaleColPtr += StrideQuantBScale;
|
||||
if (QuantBZeroPointColPtr != nullptr) {
|
||||
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
|
||||
}
|
||||
|
||||
BiasPtr += BiasPtr != nullptr ? 1 : 0;
|
||||
SumPtr += 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//
|
||||
// Kernel dispatch structure definition.
|
||||
|
|
@ -480,10 +747,11 @@ SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256)
|
|||
|
||||
const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
|
||||
MLAS_SQNBIT_GEMM_DISPATCH d;
|
||||
d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
|
||||
d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
|
||||
d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
|
||||
d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
|
||||
d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;
|
||||
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8;
|
||||
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
|
||||
|
||||
return d;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -109,12 +109,19 @@ void Q8Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) {
|
|||
|
||||
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames(q4gemm_bench_arg_names);
|
||||
ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}, {8}});
|
||||
b->ArgsProduct({{1, 1024, 2048}, {4096}, {4096}, {8}});
|
||||
}
|
||||
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
[[maybe_unused]] static const bool benchmarks_registered = []() {
|
||||
const bool is_q4gemm_supported = MlasQ4GemmPackBSize(BlkQ4Sym, 1, 1) > 0;
|
||||
if (is_q4gemm_supported) {
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -224,8 +224,7 @@ BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
|
|||
|
||||
static void General_Conv2d(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames(ArgNamesForConv(2));
|
||||
ArgsProduct(
|
||||
b,
|
||||
b->ArgsProduct(
|
||||
{{2}, // Rank,
|
||||
{1}, // N
|
||||
{1, 2}, // Groups
|
||||
|
|
|
|||
|
|
@ -103,14 +103,14 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo
|
|||
|
||||
static void GemmSizeWithOne(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames(sgemm_bench_arg_names);
|
||||
ArgsProduct(b, {{1}, {63, 255, 1023}, {63, 255, 1023}});
|
||||
ArgsProduct(b, {{63, 255, 1023}, {1}, {63, 255, 1023}});
|
||||
ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {1}});
|
||||
b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}});
|
||||
b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}});
|
||||
b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}});
|
||||
}
|
||||
|
||||
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames(sgemm_bench_arg_names);
|
||||
ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}});
|
||||
b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}});
|
||||
}
|
||||
|
||||
BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
|
|
@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct
|
|||
|
||||
static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames(sgemm_bench_arg_names);
|
||||
ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}});
|
||||
b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}});
|
||||
}
|
||||
|
||||
BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime();
|
||||
|
|
|
|||
|
|
@ -4,33 +4,36 @@
|
|||
#include "mlas_q4.h"
|
||||
#include "mlas_qnbit.h"
|
||||
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "benchmark/benchmark.h"
|
||||
|
||||
#include "bench_util.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
#include "core/common/narrow.h"
|
||||
|
||||
template <size_t BlkBitWidth, size_t BlkLen, bool Symmetric>
|
||||
using onnxruntime::narrow;
|
||||
|
||||
template <size_t BlkBitWidth>
|
||||
void SQNBITGEMM(benchmark::State& state) {
|
||||
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!");
|
||||
if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!");
|
||||
|
||||
const size_t M = static_cast<size_t>(state.range(0));
|
||||
const size_t N = static_cast<size_t>(state.range(1));
|
||||
const size_t K = static_cast<size_t>(state.range(2));
|
||||
const size_t threads = static_cast<size_t>(state.range(3));
|
||||
const auto BlkLen = narrow<size_t>(state.range(0));
|
||||
const auto M = narrow<size_t>(state.range(1));
|
||||
const auto N = narrow<size_t>(state.range(2));
|
||||
const auto K = narrow<size_t>(state.range(3));
|
||||
const auto Threads = narrow<size_t>(state.range(4));
|
||||
const auto Symmetric = narrow<bool>(state.range(5));
|
||||
const auto ComputeType = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(state.range(6));
|
||||
|
||||
size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
|
||||
MlasBlockwiseQuantizedBufferSizes(
|
||||
BlkBitWidth, BlkLen, /* columnwise */ true,
|
||||
BlkBitWidth, static_cast<int>(BlkLen), /* columnwise */ true,
|
||||
static_cast<int>(K), static_cast<int>(N),
|
||||
QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
|
||||
|
||||
OrtThreadPoolParams tpo;
|
||||
tpo.thread_pool_size = static_cast<int>(threads);
|
||||
tpo.thread_pool_size = static_cast<int>(Threads);
|
||||
tpo.auto_set_affinity = true;
|
||||
|
||||
std::unique_ptr<onnxruntime::concurrency::ThreadPool> tp(
|
||||
|
|
@ -47,14 +50,29 @@ void SQNBITGEMM(benchmark::State& state) {
|
|||
|
||||
MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData.data(), QuantBScale.data(),
|
||||
Symmetric ? nullptr : QuantBZeroPoint.data(),
|
||||
B.data(), BlkLen, /* columnwise */ true,
|
||||
B.data(), static_cast<int>(BlkLen), /* columnwise */ true,
|
||||
static_cast<int>(K), static_cast<int>(N), static_cast<int>(N),
|
||||
tp.get());
|
||||
|
||||
std::unique_ptr<std::byte[]> Workspace;
|
||||
if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType);
|
||||
WorkspaceSize > 0) {
|
||||
Workspace = std::make_unique<std::byte[]>(WorkspaceSize);
|
||||
}
|
||||
|
||||
std::unique_ptr<std::byte[]> PackedQuantBData;
|
||||
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
|
||||
PackedQuantBDataSize > 0) {
|
||||
PackedQuantBData = std::make_unique<std::byte[]>(PackedQuantBDataSize);
|
||||
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get());
|
||||
}
|
||||
|
||||
MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
|
||||
params.A = A.data();
|
||||
params.lda = K;
|
||||
params.QuantBData = QuantBData.data();
|
||||
params.QuantBData = PackedQuantBData != nullptr
|
||||
? static_cast<const void*>(PackedQuantBData.get())
|
||||
: static_cast<const void*>(QuantBData.data());
|
||||
params.QuantBScale = QuantBScale.data();
|
||||
params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data();
|
||||
params.Bias = nullptr;
|
||||
|
|
@ -62,30 +80,41 @@ void SQNBITGEMM(benchmark::State& state) {
|
|||
params.ldc = N;
|
||||
|
||||
// warm up run
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace.get(), tp.get());
|
||||
|
||||
for (auto _ : state) {
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace.get(), tp.get());
|
||||
}
|
||||
}
|
||||
|
||||
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames({"M", "N", "K", "Threads"});
|
||||
ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}});
|
||||
static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"});
|
||||
|
||||
ArgsProductWithFilter(b,
|
||||
|
||||
{{16, 32, 64, 128, 256}, // BlkLen
|
||||
{1, 1024, 2048}, // M
|
||||
{4096, 11008}, // N
|
||||
{4096, 11008}, // K
|
||||
{8}, // Threads
|
||||
{int64_t{false}, int64_t{true}}, // Symmetric
|
||||
{int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType
|
||||
|
||||
[](const std::vector<int64_t>& args) {
|
||||
return MlasIsSQNBitGemmAvailable(
|
||||
// M, N, K
|
||||
narrow<size_t>(args[1]), narrow<size_t>(args[2]), narrow<size_t>(args[3]),
|
||||
// BlkBitWidth, BlkLen
|
||||
4, narrow<size_t>(args[0]),
|
||||
// ComputeType
|
||||
static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(args[6]));
|
||||
});
|
||||
}
|
||||
|
||||
BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
|
||||
|
||||
#if defined(MLAS_JBLAS)
|
||||
|
||||
#ifdef MLAS_JBLAS
|
||||
void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE cmp_type) {
|
||||
if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!");
|
||||
if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!");
|
||||
|
|
@ -130,6 +159,11 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQ
|
|||
}
|
||||
}
|
||||
|
||||
static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
|
||||
b->ArgNames({"M", "N", "K", "Threads"});
|
||||
b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}});
|
||||
}
|
||||
|
||||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymInt8, 32, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymInt8, 128, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymInt8, -1, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
|
|
@ -137,4 +171,5 @@ BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymFp32, 32, false, CompFp32)->Apply(GemmSi
|
|||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymFp32, 128, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymFp32, -1, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32AsymFp32, 32, true, CompFp32)->Apply(GemmSizeProducts)->UseRealTime();
|
||||
#endif
|
||||
|
||||
#endif // defined(MLAS_JBLAS)
|
||||
|
|
|
|||
|
|
@ -23,10 +23,9 @@ std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_val
|
|||
return RandomVectorUniform(static_cast<size_t>(sz), min_value, max_value);
|
||||
}
|
||||
|
||||
// The Benchmark used here do not contains this as in newer version.
|
||||
// Use the code from newer version.
|
||||
void ArgsProduct(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists) {
|
||||
void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists,
|
||||
std::function<bool(const std::vector<int64_t>& args)> include_filter) {
|
||||
std::vector<std::size_t> indices(arglists.size(), 0);
|
||||
const std::size_t total = std::accumulate(
|
||||
std::begin(arglists), std::end(arglists), std::size_t{1},
|
||||
|
|
@ -39,7 +38,9 @@ void ArgsProduct(benchmark::internal::Benchmark* bench,
|
|||
for (std::size_t arg = 0; arg < arglists.size(); arg++) {
|
||||
args.push_back(arglists[arg][indices[arg]]);
|
||||
}
|
||||
bench->Args(args);
|
||||
if (include_filter(args)) {
|
||||
bench->Args(args);
|
||||
}
|
||||
args.clear();
|
||||
|
||||
std::size_t arg = 0;
|
||||
|
|
|
|||
|
|
@ -5,10 +5,14 @@
|
|||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
#include <functional>
|
||||
#include <random>
|
||||
|
||||
void ArgsProduct(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists);
|
||||
// Specifies benchmark arguments from the cartesian product of `arglists`, like Benchmark::ArgsProduct().
|
||||
// `include_filter` is called to determine whether a given set of arguments should be included.
|
||||
void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<std::vector<int64_t>>& arglists,
|
||||
std::function<bool(const std::vector<int64_t>& args)> include_filter);
|
||||
|
||||
template <typename ElementType>
|
||||
std::vector<ElementType> RandomVectorUniform(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,17 @@ Abstract:
|
|||
#include "mlas_q4.h"
|
||||
#include "mlas_qnbit.h"
|
||||
|
||||
static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) {
|
||||
switch (ComputeType) {
|
||||
case CompFp32:
|
||||
return "Fp32";
|
||||
case CompInt8:
|
||||
return "Int8";
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Test class for n-bit int block quantized GEMM
|
||||
* Note: only 2-D matmul supported for now
|
||||
|
|
@ -26,12 +37,16 @@ template <size_t BlkBitWidth, size_t BlkLen>
|
|||
class MlasSQNBitGemmTest : public MlasTestBase {
|
||||
private:
|
||||
MatrixGuardBuffer<float> BufferA;
|
||||
MatrixGuardBuffer<int8_t> BufferQuantAData;
|
||||
MatrixGuardBuffer<float> BufferQuantAScale;
|
||||
MatrixGuardBuffer<float> BufferB;
|
||||
MatrixGuardBuffer<uint8_t> BufferQuantBData;
|
||||
MatrixGuardBuffer<std::byte> BufferPackedQuantBData;
|
||||
MatrixGuardBuffer<uint8_t> BufferQuantBZeroPoint;
|
||||
MatrixGuardBuffer<float> BufferQuantBScale;
|
||||
MatrixGuardBuffer<float> BufferDequantizedB;
|
||||
MatrixGuardBuffer<float> BufferBias;
|
||||
MatrixGuardBuffer<std::byte> BufferWorkspace;
|
||||
MatrixGuardBuffer<float> BufferC;
|
||||
MatrixGuardBuffer<float> BufferCReference;
|
||||
|
||||
|
|
@ -40,12 +55,15 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
size_t K,
|
||||
const float* A,
|
||||
size_t lda,
|
||||
const uint8_t* QuantBData,
|
||||
const void* QuantBData,
|
||||
const void* PackedQuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const void* QuantBZeroPoint,
|
||||
const float* Bias,
|
||||
float* C,
|
||||
size_t ldc,
|
||||
void* Workspace,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
MLAS_THREADPOOL* Threadpool) {
|
||||
MLAS_SQNBIT_GEMM_DATA_PARAMS params;
|
||||
params.A = A;
|
||||
|
|
@ -53,23 +71,106 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
params.Bias = Bias;
|
||||
params.C = C;
|
||||
params.ldc = ldc;
|
||||
params.QuantBData = QuantBData;
|
||||
params.QuantBData = PackedQuantBData != nullptr ? PackedQuantBData : QuantBData;
|
||||
params.QuantBScale = QuantBScale;
|
||||
params.QuantBZeroPoint = QuantBZeroPoint;
|
||||
params.PostProcessor = nullptr;
|
||||
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, Threadpool);
|
||||
MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace, Threadpool);
|
||||
}
|
||||
|
||||
void CallReferenceGemm(size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const float* A,
|
||||
const uint8_t* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const float* Bias,
|
||||
float* C) {
|
||||
void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) {
|
||||
const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
|
||||
const size_t lda = K;
|
||||
for (size_t m = 0; m < M; ++m) {
|
||||
for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) {
|
||||
const size_t local_blk_len = std::min(K - k, BlkLen);
|
||||
float blk_a[BlkLen]{};
|
||||
std::copy_n(A + m * lda + k, local_blk_len, blk_a);
|
||||
|
||||
float amax = 0.0f; // max of absolute values of A block
|
||||
for (size_t kk = 0; kk < local_blk_len; ++kk) {
|
||||
float a = blk_a[kk];
|
||||
amax = std::max(amax, fabsf(a));
|
||||
}
|
||||
|
||||
constexpr float range_max = (1 << 7) - 1;
|
||||
const float scale = amax / range_max;
|
||||
const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
|
||||
|
||||
QuantAScale[m * BlockCountK + k_blk] = scale;
|
||||
|
||||
for (size_t kk = 0; kk < BlkLen; ++kk) {
|
||||
const float q = roundf(blk_a[kk] * scale_reciprocal);
|
||||
QuantAData[m * BlockCountK * BlkLen + k + kk] =
|
||||
static_cast<int8_t>(
|
||||
std::clamp(q,
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min()),
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CallReferenceGemm_CompInt8(size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const float* A,
|
||||
const uint8_t* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const float* Bias,
|
||||
float* C) {
|
||||
const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
|
||||
|
||||
int8_t* QuantAData = BufferQuantAData.GetBuffer(M * BlockCountK * BlkLen);
|
||||
float* QuantAScale = BufferQuantAScale.GetBuffer(M * BlockCountK);
|
||||
QuantizeA(M, K, A, QuantAData, QuantAScale);
|
||||
|
||||
for (size_t m = 0; m < M; ++m) {
|
||||
for (size_t n = 0; n < N; ++n) {
|
||||
float sum = Bias == nullptr ? 0.0f : Bias[n];
|
||||
for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) {
|
||||
const size_t k_blk_len = std::min(K - k, BlkLen);
|
||||
|
||||
const float a_scale = QuantAScale[m * BlockCountK + k_blk];
|
||||
|
||||
const float b_scale = QuantBScale[n * BlockCountK + k_blk];
|
||||
|
||||
static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B");
|
||||
|
||||
uint8_t b_zp = 8;
|
||||
if (QuantBZeroPoint != nullptr) {
|
||||
const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2];
|
||||
b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F);
|
||||
}
|
||||
|
||||
int32_t qsum = 0;
|
||||
|
||||
for (size_t kk = 0; kk < k_blk_len; ++kk) {
|
||||
const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk];
|
||||
const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2];
|
||||
const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp;
|
||||
qsum += qa * qb;
|
||||
}
|
||||
|
||||
sum += static_cast<float>(qsum) * a_scale * b_scale;
|
||||
}
|
||||
|
||||
C[m * N + n] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CallReferenceGemm_CompFp32(size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const float* A,
|
||||
const uint8_t* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const uint8_t* QuantBZeroPoint,
|
||||
const float* Bias,
|
||||
float* C) {
|
||||
float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N);
|
||||
MlasDequantizeBlockwise<float, BlkBitWidth>(
|
||||
DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true,
|
||||
|
|
@ -95,6 +196,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
|
||||
public:
|
||||
void Test(size_t M, size_t N, size_t K,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
bool WithBias, bool Symmetric, bool WithThreadpool) {
|
||||
MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr;
|
||||
|
||||
|
|
@ -126,7 +228,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
float* C = BufferC.GetBuffer(N * M, true);
|
||||
float* CReference = BufferCReference.GetBuffer(N * M, true);
|
||||
|
||||
// pack B
|
||||
// quantize B
|
||||
uint8_t* QuantBData = nullptr;
|
||||
float* QuantBScale = nullptr;
|
||||
uint8_t* QuantBZeroPoint = nullptr;
|
||||
|
|
@ -138,20 +240,48 @@ class MlasSQNBitGemmTest : public MlasTestBase {
|
|||
|
||||
QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
|
||||
QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
|
||||
if (Symmetric) {
|
||||
if (!Symmetric) {
|
||||
QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes);
|
||||
}
|
||||
|
||||
MlasQuantizeBlockwise<float, 4>(QuantBData, QuantBScale, QuantBZeroPoint,
|
||||
B, BlkLen,
|
||||
/* columnwise */ true,
|
||||
static_cast<int>(K), static_cast<int>(N),
|
||||
static_cast<int>(N),
|
||||
GetMlasThreadPool());
|
||||
MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
|
||||
B, BlkLen,
|
||||
/* columnwise */ true,
|
||||
static_cast<int>(K), static_cast<int>(N),
|
||||
static_cast<int>(N),
|
||||
GetMlasThreadPool());
|
||||
}
|
||||
|
||||
CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool);
|
||||
CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
|
||||
void* Workspace = nullptr;
|
||||
if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType);
|
||||
WorkspaceSize > 0) {
|
||||
Workspace = BufferWorkspace.GetBuffer(WorkspaceSize);
|
||||
}
|
||||
|
||||
void* PackedQuantBData = nullptr;
|
||||
if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
|
||||
PackedQuantBDataSize > 0) {
|
||||
PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize);
|
||||
MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool());
|
||||
}
|
||||
|
||||
if (ComputeType == CompFp32) {
|
||||
CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
|
||||
} else if (ComputeType == CompInt8) {
|
||||
CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
|
||||
} else {
|
||||
FAIL() << "Test is not implemented for compute type "
|
||||
<< ComputeType << " (" << ComputeTypeName(ComputeType) << ")";
|
||||
}
|
||||
|
||||
CallGemm(M, N, K,
|
||||
A, /* lda */ K,
|
||||
QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint,
|
||||
Bias,
|
||||
C, /* ldc */ N,
|
||||
Workspace,
|
||||
ComputeType,
|
||||
Threadpool);
|
||||
|
||||
size_t f = 0;
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
|
|
@ -179,74 +309,90 @@ template <size_t BlkBitWidth, size_t BlkLen>
|
|||
class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>> {
|
||||
public:
|
||||
explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
bool WithThreadpool, bool Symmetric, bool WithBias)
|
||||
: M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) {
|
||||
: M_(M),
|
||||
N_(N),
|
||||
K_(K),
|
||||
ComputeType_(ComputeType),
|
||||
WithThreadpool_(WithThreadpool),
|
||||
Symmetric_(Symmetric),
|
||||
WithBias_(WithBias) {
|
||||
}
|
||||
|
||||
void TestBody() override {
|
||||
MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>::mlas_tester->Test(
|
||||
M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_);
|
||||
M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_);
|
||||
}
|
||||
|
||||
static size_t RegisterSingleTest(size_t M, size_t N, size_t K,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
|
||||
bool WithThreadpool, bool Symmetric, bool WithBias) {
|
||||
std::stringstream ss;
|
||||
ss << (WithThreadpool ? "SingleThread" : "Threaded")
|
||||
<< "/isSymmetric" << Symmetric
|
||||
<< "/M" << M << "xN" << N << "xK" << K
|
||||
<< "/hasBias" << WithBias;
|
||||
auto test_name = ss.str();
|
||||
size_t tests_registered = 0;
|
||||
|
||||
testing::RegisterTest(
|
||||
MlasSQNBitGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
|
||||
test_name.c_str(),
|
||||
nullptr,
|
||||
test_name.c_str(),
|
||||
__FILE__,
|
||||
__LINE__,
|
||||
// Important to use the fixture type as the return type here.
|
||||
[=]() -> MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>* {
|
||||
return new SQNBitGemmShortExecuteTest(
|
||||
M, N, K, WithThreadpool, Symmetric, WithBias);
|
||||
});
|
||||
if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType)) {
|
||||
std::stringstream ss;
|
||||
ss << (WithThreadpool ? "SingleThread" : "Threaded")
|
||||
<< "/isSymmetric" << Symmetric
|
||||
<< "/M" << M << "xN" << N << "xK" << K
|
||||
<< "/hasBias" << WithBias
|
||||
<< "/computeType" << ComputeTypeName(ComputeType);
|
||||
auto test_name = ss.str();
|
||||
|
||||
return 1;
|
||||
testing::RegisterTest(
|
||||
MlasSQNBitGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
|
||||
test_name.c_str(),
|
||||
nullptr,
|
||||
test_name.c_str(),
|
||||
__FILE__,
|
||||
__LINE__,
|
||||
// Important to use the fixture type as the return type here.
|
||||
[=]() -> MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>* {
|
||||
return new SQNBitGemmShortExecuteTest(
|
||||
M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias);
|
||||
});
|
||||
|
||||
tests_registered += 1;
|
||||
}
|
||||
|
||||
return tests_registered;
|
||||
}
|
||||
|
||||
static size_t RegisterShortExecuteTests() {
|
||||
size_t test_registered = 0;
|
||||
size_t tests_registered = 0;
|
||||
|
||||
if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) {
|
||||
for (MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType : {CompFp32, CompInt8}) {
|
||||
for (bool WithThreadpool : {false, true}) {
|
||||
for (bool Symmetric : {false, true}) {
|
||||
for (size_t b = 1; b < 16; b++) {
|
||||
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
|
||||
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
|
||||
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false);
|
||||
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
|
||||
}
|
||||
for (size_t b = 16; b <= 256; b <<= 1) {
|
||||
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
|
||||
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
|
||||
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false);
|
||||
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
|
||||
}
|
||||
for (size_t b = 256; b < 320; b += 32) {
|
||||
test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
|
||||
tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true);
|
||||
}
|
||||
for (size_t b = 1; b < 96; b++) {
|
||||
test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false);
|
||||
test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true);
|
||||
test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false);
|
||||
tests_registered += RegisterSingleTest(1, b, 32, ComputeType, WithThreadpool, Symmetric, false);
|
||||
tests_registered += RegisterSingleTest(1, 32, b, ComputeType, WithThreadpool, Symmetric, true);
|
||||
tests_registered += RegisterSingleTest(1, b, b, ComputeType, WithThreadpool, Symmetric, false);
|
||||
}
|
||||
test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true);
|
||||
tests_registered += RegisterSingleTest(43, 500, 401, ComputeType, WithThreadpool, Symmetric, true);
|
||||
|
||||
// test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false);
|
||||
// tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return test_registered;
|
||||
return tests_registered;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t M_, N_, K_;
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType_;
|
||||
bool WithThreadpool_, Symmetric_, WithBias_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue