SQNBitGemm - move workspace size calculation functions to hardware-specific implementations (#20757)

The workspace usage may be hardware-specific. Moving away from a common workspace size calculation allows more flexibility in the hardware-specific implementations.
This commit is contained in:
Edward Chen 2024-05-22 15:12:17 -07:00 committed by GitHub
parent d4fe4b5b51
commit a39f8862fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 256 additions and 89 deletions

View file

@ -38,6 +38,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
)
target_sources(onnxruntime_mlas PRIVATE

View file

@ -16,6 +16,7 @@ Abstract:
--*/
#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"
#include <cassert>
@ -91,54 +92,59 @@ MlasIsSQNBitGemmAvailable(
namespace
{
//
// Returns the byte alignment required for the workspace of the given variant.
// Only the 4-bit / CompInt8 variant has a requirement (that of a Q8 block);
// every other variant reports an alignment of 1.
//
size_t
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
{
    if (Variant == SQNBitGemmVariant_BitWidth4_CompInt8) {
        return Q8BlkAlignment();
    }

    return 1;
}
size_t
SQNBitGemmPerGemmWorkspaceSize(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
MLAS_UNREFERENCED_PARAMETER(N);
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
// workspace buffer is used for block quantization of A to int8
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
return PerGemmWorkspaceSize;
}
default: {
return 0;
}
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return 0;
}
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) {
return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType);
}
return 0;
}
//
// Returns the byte alignment of the per-GEMM workspace as reported by the
// platform's SQNBitGemm dispatch. Falls back to 1 when there is no dispatch,
// when BlkBitWidth is not 4, or when the dispatch does not supply an
// alignment function.
//
size_t
SQNBitGemmPerGemmWorkspaceAlignment(
    size_t BlkBitWidth,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    constexpr size_t NoAlignmentRequirement = 1;

    const auto* const PlatformDispatch = GetMlasPlatform().SQNBitGemmDispatch;

    // Short-circuit evaluation keeps the null check ahead of any dereference.
    const bool HasAlignmentFn =
        PlatformDispatch != nullptr &&
        BlkBitWidth == 4 &&
        PlatformDispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr;

    return HasAlignmentFn
               ? PlatformDispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType)
               : NoAlignmentRequirement;
}
size_t
SQNBitGemmPerGemmWorkspaceStride(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
return MlasDivRoundup(Size, Alignment) * Alignment;
}
@ -155,14 +161,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (PerGemmWorkspaceStride == 0) {
return 0;
}
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
@ -574,14 +578,14 @@ MlasSQNBitGemmBatch(
// Ensure `Workspace` has correct alignment.
//
if (Workspace != nullptr) {
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
Workspace = reinterpret_cast<void*>(
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
);
}
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
InitializeWorkspaceOperation != nullptr) {

View file

@ -22,8 +22,6 @@ Abstract:
#pragma once
#include <cassert>
#include "mlas_qnbit.h"
#include "mlasi.h"
@ -44,56 +42,6 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
}
}
//
// Quantized int8 block helpers.
//
// Read-only access to the float stored at the very start of the Q8 block at `BlkPtr`.
MLAS_FORCEINLINE
const float&
Q8BlkScale(const std::byte* BlkPtr)
{
    const auto* ScalePtr = reinterpret_cast<const float*>(BlkPtr);
    return *ScalePtr;
}
// Mutable access to the float stored at the very start of the Q8 block at `BlkPtr`.
MLAS_FORCEINLINE
float&
Q8BlkScale(std::byte* BlkPtr)
{
    auto* ScalePtr = reinterpret_cast<float*>(BlkPtr);
    return *ScalePtr;
}
// Read-only pointer to the int8 values of a Q8 block; they begin sizeof(float) bytes past `BlkPtr`.
MLAS_FORCEINLINE
const int8_t*
Q8BlkData(const std::byte* BlkPtr)
{
    const std::byte* DataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<const int8_t*>(DataStart);
}
// Mutable pointer to the int8 values of a Q8 block; they begin sizeof(float) bytes past `BlkPtr`.
MLAS_FORCEINLINE
int8_t*
Q8BlkData(std::byte* BlkPtr)
{
    std::byte* DataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<int8_t*>(DataStart);
}
// Size in bytes of one Q8 block holding `BlkLen` int8 values: a float followed by the values.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkSize(size_t BlkLen)
{
    constexpr size_t ScaleBytes = sizeof(float);
    const size_t DataBytes = BlkLen * sizeof(int8_t);
    const size_t TotalBytes = ScaleBytes + DataBytes;

    // The float is currently the most strictly aligned part of a block, so the
    // total size must be a multiple of its alignment for back-to-back blocks
    // to stay suitably aligned.
    assert(TotalBytes % alignof(float) == 0);

    return TotalBytes;
}
// Byte alignment of a Q8 block, which is the alignment of float.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkAlignment()
{
    constexpr size_t FloatAlignment = alignof(float);
    return FloatAlignment;
}
//
// Kernel dispatch structure.
//
@ -126,6 +74,43 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
//
// Workspace size calculation function prototypes.
//
/**
 * @brief Gets the required size in bytes of the per-GEMM intermediate workspace.
 *        Returns a size of zero if no intermediate workspace is needed.
 *
 * @param[in]   M               row size of matrix A and C
 * @param[in]   N               column size of matrix B and C
 * @param[in]   K               column size of matrix A and row size of matrix B
 * @param[in]   BlkLen          number of quantized values per block
 * @param[in]   ComputeType     GEMM compute type (e.g., multiplying float or int8 values)
 *
 * @return  The per-GEMM workspace size in bytes, or zero if none is needed.
 */
typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
size_t M,
size_t N,
size_t K,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
// Defaults to nullptr; a hardware-specific dispatch assigns its own
// implementation when it provides this calculation.
SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr;
/**
 * @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
 *
 * @param[in]   BlkLen          number of quantized values per block
 * @param[in]   ComputeType     GEMM compute type (e.g., multiplying float or int8 values)
 *
 * @return  The required alignment in bytes.
 */
typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);
// Defaults to nullptr; a hardware-specific dispatch assigns its own
// implementation when it provides this calculation.
SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr;
//
// CompFp32 kernel function prototypes.
//

View file

@ -1103,6 +1103,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

View file

@ -233,6 +233,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

View file

@ -254,6 +254,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

View file

@ -1,5 +1,6 @@
#pragma once
#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"
//
// Quantized B data packing function implementation.
@ -99,6 +100,52 @@ SQ4BitGemmPackQuantBData(
);
}
//
// Workspace size calculation function implementation.
//
//
// Computes the per-GEMM workspace size in bytes.
// For CompInt8 the workspace holds A block-quantized to int8: one Q8 block per
// BlkLen-sized chunk of K (rounded up), for each of the M rows.
// Every other compute type needs no workspace and yields 0. N is not used.
//
static size_t
SQ4BitGemmPerGemmWorkspaceSize(
    size_t M,
    size_t N,
    size_t K,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(N);

    if (ComputeType != CompInt8) {
        return 0;
    }

    const size_t BlocksPerRow = MlasDivRoundup(K, BlkLen);
    return M * BlocksPerRow * Q8BlkSize(BlkLen);
}
//
// Computes the byte alignment of the per-GEMM workspace.
// CompInt8 uses the Q8 block alignment; every other compute type yields 1.
// BlkLen is not used.
//
static size_t
SQ4BitGemmPerGemmWorkspaceAlignment(
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(BlkLen);

    return (ComputeType == CompInt8) ? Q8BlkAlignment() : size_t{1};
}
void
Q4BitBlkDequantBForSgemm_CompFp32_avx2(
const size_t BlkLen,

View file

@ -5,6 +5,7 @@
#include "sqnbitgemm.h"
#include "sqnbitgemm_kernel_avx_common.h"
#include "sqnbitgemm_q8_block.h"
void
SQ4BitGemmM1Kernel_CompInt8_avx2(

View file

@ -22,6 +22,7 @@ Abstract:
#include <utility>
#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"
//
// Quantized B data packing function implementation.
@ -118,6 +119,52 @@ SQ4BitGemmPackQuantBData(
);
}
//
// Workspace size calculation function implementation.
//
//
// Computes the per-GEMM workspace size in bytes.
// For CompInt8 the workspace holds A block-quantized to int8: one Q8 block per
// BlkLen-sized chunk of K (rounded up), for each of the M rows.
// Every other compute type needs no workspace and yields 0. N is not used.
//
size_t
SQ4BitGemmPerGemmWorkspaceSize(
    size_t M,
    size_t N,
    size_t K,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(N);

    if (ComputeType != CompInt8) {
        return 0;
    }

    const size_t BlocksPerRow = MlasDivRoundup(K, BlkLen);
    return M * BlocksPerRow * Q8BlkSize(BlkLen);
}
//
// Computes the byte alignment of the per-GEMM workspace.
// CompInt8 uses the Q8 block alignment; every other compute type yields 1.
// BlkLen is not used.
//
size_t
SQ4BitGemmPerGemmWorkspaceAlignment(
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(BlkLen);

    return (ComputeType == CompInt8) ? Q8BlkAlignment() : size_t{1};
}
} // namespace
//
@ -1441,6 +1488,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;

View file

@ -0,0 +1,70 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
sqnbitgemm_q8_block.h
Abstract:
This module includes helper functions for manipulating blocks of quantized
int8 (Q8) values.
--*/
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include "mlasi.h"
// Read-only access to the float stored at the very start of the Q8 block at `BlkPtr`.
MLAS_FORCEINLINE
const float&
Q8BlkScale(const std::byte* BlkPtr)
{
    const auto* ScalePtr = reinterpret_cast<const float*>(BlkPtr);
    return *ScalePtr;
}
// Mutable access to the float stored at the very start of the Q8 block at `BlkPtr`.
MLAS_FORCEINLINE
float&
Q8BlkScale(std::byte* BlkPtr)
{
    auto* ScalePtr = reinterpret_cast<float*>(BlkPtr);
    return *ScalePtr;
}
// Read-only pointer to the int8 values of a Q8 block; they begin sizeof(float) bytes past `BlkPtr`.
MLAS_FORCEINLINE
const int8_t*
Q8BlkData(const std::byte* BlkPtr)
{
    const std::byte* DataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<const int8_t*>(DataStart);
}
// Mutable pointer to the int8 values of a Q8 block; they begin sizeof(float) bytes past `BlkPtr`.
MLAS_FORCEINLINE
int8_t*
Q8BlkData(std::byte* BlkPtr)
{
    std::byte* DataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<int8_t*>(DataStart);
}
// Size in bytes of one Q8 block holding `BlkLen` int8 values: a float followed by the values.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkSize(size_t BlkLen)
{
    constexpr size_t ScaleBytes = sizeof(float);
    const size_t DataBytes = BlkLen * sizeof(int8_t);
    const size_t TotalBytes = ScaleBytes + DataBytes;

    // The float is currently the most strictly aligned part of a block, so the
    // total size must be a multiple of its alignment for back-to-back blocks
    // to stay suitably aligned.
    assert(TotalBytes % alignof(float) == 0);

    return TotalBytes;
}
// Byte alignment of a Q8 block, which is the alignment of float.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkAlignment()
{
    constexpr size_t FloatAlignment = alignof(float);
    return FloatAlignment;
}