diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 682dcfc5fe..304aa77f54 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -38,6 +38,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ) target_sources(onnxruntime_mlas PRIVATE diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 4fd28f5e69..4b852be951 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,6 +16,7 @@ Abstract: --*/ #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" #include @@ -91,54 +92,59 @@ MlasIsSQNBitGemmAvailable( namespace { -size_t -SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant) -{ - switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Q8BlkAlignment(); - } - default: { - return 1; - } - } -} - size_t SQNBitGemmPerGemmWorkspaceSize( - SQNBitGemmVariant Variant, size_t M, size_t N, size_t K, - size_t BlkLen + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - - switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompInt8: { - // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - return PerGemmWorkspaceSize; - } - default: { - return 0; - } + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + } + + return 0; +} + +size_t +SQNBitGemmPerGemmWorkspaceAlignment( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 1; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + } + + return 1; } size_t SQNBitGemmPerGemmWorkspaceStride( - SQNBitGemmVariant Variant, size_t M, size_t N, size_t K, - size_t BlkLen + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen); - const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } @@ -155,14 +161,12 @@ MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); - - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } - const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; @@ -574,14 +578,14 @@ MlasSQNBitGemmBatch( // Ensure `Workspace` has correct alignment. // if (Workspace != nullptr) { - const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); Workspace = reinterpret_cast( (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) ); } - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; InitializeWorkspaceOperation != nullptr) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 318a51e1c8..effb59b250 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -22,8 +22,6 @@ Abstract: #pragma once -#include - #include "mlas_qnbit.h" #include "mlasi.h" @@ -44,56 +42,6 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) } } -// -// Quantized int8 block helpers. -// - -MLAS_FORCEINLINE -const float& -Q8BlkScale(const std::byte* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -MLAS_FORCEINLINE -float& -Q8BlkScale(std::byte* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -MLAS_FORCEINLINE -const int8_t* -Q8BlkData(const std::byte* BlkPtr) -{ - return reinterpret_cast(BlkPtr + sizeof(float)); -} - -MLAS_FORCEINLINE -int8_t* -Q8BlkData(std::byte* BlkPtr) -{ - return reinterpret_cast(BlkPtr + sizeof(float)); -} - -MLAS_FORCEINLINE -constexpr size_t -Q8BlkSize(size_t BlkLen) -{ - const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(BlkSize % alignof(float) == 0); - return BlkSize; -} - -MLAS_FORCEINLINE -constexpr size_t -Q8BlkAlignment() -{ - return alignof(float); -} - // // Kernel dispatch structure. // @@ -126,6 +74,43 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + // + // Workspace size calculation function prototypes. + // + + /** + * @brief Gets the required size in bytes of the per-GEMM intermediate workspace. + * Returns a size of zero if no intermediate workspace is needed. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + + /** + * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. + * + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + // // CompFp32 kernel function prototypes. // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index b5d7a4e78f..be573381c3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1103,6 +1103,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 1eca0960cf..0099b61d81 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -233,6 +233,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 45a69c4f20..27310d8253 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -254,6 +254,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index abace949a1..cfc0564cd0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -1,5 +1,6 @@ #pragma once #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" // // Quantized B data packing function implementation. @@ -99,6 +100,52 @@ SQ4BitGemmPackQuantBData( ); } +// +// Workspace size calculation function implementation. +// + +static size_t +SQ4BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch(ComputeType) { + case CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +static size_t +SQ4BitGemmPerGemmWorkspaceAlignment( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + switch (ComputeType) { + case CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + void Q4BitBlkDequantBForSgemm_CompFp32_avx2( const size_t BlkLen, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 8f8506cb3a..250ffeacd7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -5,6 +5,7 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_q8_block.h" void SQ4BitGemmM1Kernel_CompInt8_avx2( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index ffa8b79ebd..6d1864794f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -22,6 +22,7 @@ Abstract: #include #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" // // Quantized B data packing function implementation. @@ -118,6 +119,52 @@ SQ4BitGemmPackQuantBData( ); } +// +// Workspace size calculation function implementation. +// + +size_t +SQ4BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch(ComputeType) { + case CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +size_t +SQ4BitGemmPerGemmWorkspaceAlignment( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + switch (ComputeType) { + case CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + } // namespace // @@ -1441,6 +1488,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h b/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h new file mode 100644 index 0000000000..80af2f4679 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h @@ -0,0 +1,70 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_q8_block.h + +Abstract: + + This module includes helper functions for manipulating blocks of quantized + int8 (Q8) values. + +--*/ + +#pragma once + +#include +#include +#include + +#include "mlasi.h" + +MLAS_FORCEINLINE +const float& +Q8BlkScale(const std::byte* BlkPtr) +{ + return *reinterpret_cast(BlkPtr); +} + +MLAS_FORCEINLINE +float& +Q8BlkScale(std::byte* BlkPtr) +{ + return *reinterpret_cast(BlkPtr); +} + +MLAS_FORCEINLINE +const int8_t* +Q8BlkData(const std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} + +MLAS_FORCEINLINE +int8_t* +Q8BlkData(std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(BlkSize % alignof(float) == 0); + return BlkSize; +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkAlignment() +{ + return alignof(float); +}