mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
SQNBitGemm - move workspace size calculation functions to hardware-specific implementations (#20757)
The workspace usage may be hardware-specific. Moving away from a common workspace size calculation allows more flexibility in the hardware-specific implementations.
This commit is contained in:
parent
d4fe4b5b51
commit
a39f8862fd
10 changed files with 256 additions and 89 deletions
|
|
@ -38,6 +38,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm.h
|
||||
${MLAS_SRC_DIR}/sqnbitgemm.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
|
||||
)
|
||||
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ Abstract:
|
|||
--*/
|
||||
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
|
|
@ -91,54 +92,59 @@ MlasIsSQNBitGemmAvailable(
|
|||
namespace
|
||||
{
|
||||
|
||||
size_t
|
||||
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
|
||||
{
|
||||
switch (Variant) {
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
return Q8BlkAlignment();
|
||||
}
|
||||
default: {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
SQNBitGemmPerGemmWorkspaceSize(
|
||||
SQNBitGemmVariant Variant,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
|
||||
switch (Variant) {
|
||||
case SQNBitGemmVariant_BitWidth4_CompInt8: {
|
||||
// workspace buffer is used for block quantization of A to int8
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
|
||||
return PerGemmWorkspaceSize;
|
||||
}
|
||||
default: {
|
||||
return 0;
|
||||
}
|
||||
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
|
||||
if (Dispatch == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) {
|
||||
return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t
|
||||
SQNBitGemmPerGemmWorkspaceAlignment(
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
|
||||
if (Dispatch == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) {
|
||||
return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t
|
||||
SQNBitGemmPerGemmWorkspaceStride(
|
||||
SQNBitGemmVariant Variant,
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen
|
||||
size_t BlkBitWidth,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
|
||||
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
|
||||
return MlasDivRoundup(Size, Alignment) * Alignment;
|
||||
}
|
||||
|
||||
|
|
@ -155,14 +161,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
|
|||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
|
||||
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
if (PerGemmWorkspaceStride == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
|
||||
|
||||
const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;
|
||||
|
||||
|
|
@ -574,14 +578,14 @@ MlasSQNBitGemmBatch(
|
|||
// Ensure `Workspace` has correct alignment.
|
||||
//
|
||||
if (Workspace != nullptr) {
|
||||
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
|
||||
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
|
||||
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
|
||||
Workspace = reinterpret_cast<void*>(
|
||||
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
|
||||
);
|
||||
}
|
||||
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
|
||||
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
|
||||
|
||||
if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
|
||||
InitializeWorkspaceOperation != nullptr) {
|
||||
|
|
|
|||
|
|
@ -22,8 +22,6 @@ Abstract:
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlas_qnbit.h"
|
||||
#include "mlasi.h"
|
||||
|
||||
|
|
@ -44,56 +42,6 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
|
|||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Quantized int8 block helpers.
|
||||
//
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
const float&
|
||||
Q8BlkScale(const std::byte* BlkPtr)
|
||||
{
|
||||
return *reinterpret_cast<const float*>(BlkPtr);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
float&
|
||||
Q8BlkScale(std::byte* BlkPtr)
|
||||
{
|
||||
return *reinterpret_cast<float*>(BlkPtr);
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
const int8_t*
|
||||
Q8BlkData(const std::byte* BlkPtr)
|
||||
{
|
||||
return reinterpret_cast<const int8_t*>(BlkPtr + sizeof(float));
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
int8_t*
|
||||
Q8BlkData(std::byte* BlkPtr)
|
||||
{
|
||||
return reinterpret_cast<int8_t*>(BlkPtr + sizeof(float));
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
constexpr size_t
|
||||
Q8BlkSize(size_t BlkLen)
|
||||
{
|
||||
const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t);
|
||||
// Currently, the strictest alignment requirement of a block is for a float.
|
||||
// Ensure contiguous blocks are suitably aligned.
|
||||
assert(BlkSize % alignof(float) == 0);
|
||||
return BlkSize;
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
constexpr size_t
|
||||
Q8BlkAlignment()
|
||||
{
|
||||
return alignof(float);
|
||||
}
|
||||
|
||||
//
|
||||
// Kernel dispatch structure.
|
||||
//
|
||||
|
|
@ -126,6 +74,43 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
|
|||
|
||||
SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
|
||||
|
||||
//
|
||||
// Workspace size calculation function prototypes.
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief Gets the required size in bytes of the per-GEMM intermediate workspace.
|
||||
* Returns a size of zero if no intermediate workspace is needed.
|
||||
*
|
||||
* @param[in] M row size of matrix A and C
|
||||
* @param[in] N column size of matrix B and C
|
||||
* @param[in] K column size of matrix A and row size of matrix B
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
|
||||
*/
|
||||
typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
);
|
||||
|
||||
SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
|
||||
*
|
||||
* @param[in] BlkLen number of quantized values per block
|
||||
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
|
||||
*/
|
||||
typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
);
|
||||
|
||||
SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr;
|
||||
|
||||
//
|
||||
// CompFp32 kernel function prototypes.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -1103,6 +1103,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
|
|||
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
|
||||
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
|
||||
|
||||
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
|
||||
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
|
|
|
|||
|
|
@ -233,6 +233,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
|
|||
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
|
||||
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
|
||||
|
||||
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
|
||||
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
|
|
|
|||
|
|
@ -254,6 +254,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
|
|||
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
|
||||
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
|
||||
|
||||
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
|
||||
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
//
|
||||
// Quantized B data packing function implementation.
|
||||
|
|
@ -99,6 +100,52 @@ SQ4BitGemmPackQuantBData(
|
|||
);
|
||||
}
|
||||
|
||||
//
|
||||
// Workspace size calculation function implementation.
|
||||
//
|
||||
|
||||
static size_t
|
||||
SQ4BitGemmPerGemmWorkspaceSize(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
|
||||
switch(ComputeType) {
|
||||
case CompInt8: {
|
||||
// workspace buffer is used for block quantization of A to int8
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
|
||||
return PerGemmWorkspaceSize;
|
||||
}
|
||||
default: {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
SQ4BitGemmPerGemmWorkspaceAlignment(
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(BlkLen);
|
||||
|
||||
switch (ComputeType) {
|
||||
case CompInt8: {
|
||||
return Q8BlkAlignment();
|
||||
}
|
||||
default: {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_avx2(
|
||||
const size_t BlkLen,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_kernel_avx_common.h"
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompInt8_avx2(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ Abstract:
|
|||
#include <utility>
|
||||
|
||||
#include "sqnbitgemm.h"
|
||||
#include "sqnbitgemm_q8_block.h"
|
||||
|
||||
//
|
||||
// Quantized B data packing function implementation.
|
||||
|
|
@ -118,6 +119,52 @@ SQ4BitGemmPackQuantBData(
|
|||
);
|
||||
}
|
||||
|
||||
//
|
||||
// Workspace size calculation function implementation.
|
||||
//
|
||||
|
||||
size_t
|
||||
SQ4BitGemmPerGemmWorkspaceSize(
|
||||
size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(N);
|
||||
|
||||
switch(ComputeType) {
|
||||
case CompInt8: {
|
||||
// workspace buffer is used for block quantization of A to int8
|
||||
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
|
||||
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
|
||||
return PerGemmWorkspaceSize;
|
||||
}
|
||||
default: {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
SQ4BitGemmPerGemmWorkspaceAlignment(
|
||||
size_t BlkLen,
|
||||
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
|
||||
)
|
||||
{
|
||||
MLAS_UNREFERENCED_PARAMETER(BlkLen);
|
||||
|
||||
switch (ComputeType) {
|
||||
case CompInt8: {
|
||||
return Q8BlkAlignment();
|
||||
}
|
||||
default: {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//
|
||||
|
|
@ -1441,6 +1488,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
|
|||
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
|
||||
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
|
||||
|
||||
d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
|
||||
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;
|
||||
|
||||
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
|
||||
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;
|
||||
|
||||
|
|
|
|||
70
onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h
Normal file
70
onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Module Name:
|
||||
|
||||
sqnbitgemm_q8_block.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This module includes helper functions for manipulating blocks of quantized
|
||||
int8 (Q8) values.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
//
// Returns a read-only reference to the float located at the start of the
// Q8 block pointed to by BlkPtr.
//
MLAS_FORCEINLINE
const float&
Q8BlkScale(const std::byte* BlkPtr)
{
    const auto* const ScalePtr = reinterpret_cast<const float*>(BlkPtr);
    return *ScalePtr;
}
|
||||
|
||||
//
// Returns a writable reference to the float located at the start of the
// Q8 block pointed to by BlkPtr.
//
MLAS_FORCEINLINE
float&
Q8BlkScale(std::byte* BlkPtr)
{
    auto* const ScalePtr = reinterpret_cast<float*>(BlkPtr);
    return *ScalePtr;
}
|
||||
|
||||
//
// Returns a read-only pointer to the int8 values of a Q8 block, which begin
// sizeof(float) bytes past BlkPtr.
//
MLAS_FORCEINLINE
const int8_t*
Q8BlkData(const std::byte* BlkPtr)
{
    const std::byte* const QuantDataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<const int8_t*>(QuantDataStart);
}
|
||||
|
||||
//
// Returns a writable pointer to the int8 values of a Q8 block, which begin
// sizeof(float) bytes past BlkPtr.
//
MLAS_FORCEINLINE
int8_t*
Q8BlkData(std::byte* BlkPtr)
{
    std::byte* const QuantDataStart = BlkPtr + sizeof(float);
    return reinterpret_cast<int8_t*>(QuantDataStart);
}
|
||||
|
||||
//
// Returns the size in bytes of one Q8 block holding BlkLen quantized values:
// one float followed by BlkLen int8 values.
//
MLAS_FORCEINLINE
constexpr size_t
Q8BlkSize(size_t BlkLen)
{
    const size_t ScaleBytes = sizeof(float);
    const size_t QuantDataBytes = BlkLen * sizeof(int8_t);
    const size_t TotalBytes = ScaleBytes + QuantDataBytes;

    // The float is currently the most strictly aligned part of a block.
    // Back-to-back blocks stay suitably aligned only if the block size is a
    // multiple of that alignment.
    assert(TotalBytes % alignof(float) == 0);

    return TotalBytes;
}
|
||||
|
||||
//
// Returns the byte alignment of a Q8 block, which is the alignment of float.
//
MLAS_FORCEINLINE
constexpr size_t
Q8BlkAlignment()
{
    constexpr size_t ScaleAlignment = alignof(float);
    return ScaleAlignment;
}
|
||||
Loading…
Reference in a new issue