mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
[MLAS][AArch64] SQNBitGemm M>1 CompFp32 kernel optimization (#20319)
Add ARM NEON intrinsics implementation for `Q4BitBlkDequantBForSgemm_CompFp32`.
This commit is contained in:
parent
6bd6d879a3
commit
ccaa4d1db2
3 changed files with 324 additions and 107 deletions
|
|
@ -168,6 +168,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
|
|||
*
|
||||
* @param BlkLen Number of values in a block.
|
||||
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
|
||||
* It should have enough space for
|
||||
* (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen
|
||||
* elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are
|
||||
* useful, but the kernel implementation can be simplified with the extra space.
|
||||
* @param QuantBData Supplies the quantized B matrix block data.
|
||||
* @param QuantBScale Supplies the quantized B matrix block scale values.
|
||||
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
|
||||
|
|
|
|||
|
|
@ -141,8 +141,8 @@ UnrolledLoop(IterationFn&& f)
|
|||
UnrolledLoopIterations(std::forward<IterationFn>(f), std::make_index_sequence<N>());
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE float32x4_t
|
||||
FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
|
||||
MLAS_FORCEINLINE void
|
||||
Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3)
|
||||
{
|
||||
// aN: aN_0 aN_1 aN_2 aN_3
|
||||
|
||||
|
|
@ -159,7 +159,12 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
|
|||
a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
|
||||
// a0_3 a1_3 a2_3 a3_3
|
||||
a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
|
||||
}
|
||||
|
||||
// Reduces four accumulator vectors to a single vector whose lane N is the
// sum of the four lanes of aN.
MLAS_FORCEINLINE float32x4_t
FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
{
    // After the in-place transpose, aK holds lane K of each original input,
    // so adding the four vectors lane-wise yields the per-input totals.
    Transpose4x4(a0, a1, a2, a3);

    const float32x4_t sum01 = vaddq_f32(a0, a1);
    const float32x4_t sum23 = vaddq_f32(a2, a3);
    return vaddq_f32(sum01, sum23);
}
|
||||
|
||||
|
|
@ -205,6 +210,26 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
|
|||
namespace
|
||||
{
|
||||
|
||||
namespace fp32_conversion
|
||||
{
|
||||
|
||||
// Manual conversion to float takes place in two steps:
|
||||
// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f].
|
||||
// This target float range is convenient because the 4-bit source values can be placed directly into the
|
||||
// target float bits.
|
||||
// 2. Subtract the conversion offset of 16 from the float result.
|
||||
|
||||
// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values.
|
||||
constexpr uint16_t float_high_half_template = 0b0'10000011'0000000;
|
||||
// sign|exponent|partial mantissa
|
||||
// +|131: 2^4|~~~~ <- 4 bits go here
|
||||
|
||||
const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template);
|
||||
|
||||
constexpr float offset = 16.0f;
|
||||
|
||||
} // namespace fp32_conversion
|
||||
|
||||
template <size_t NCols, bool HasZeroPoint>
|
||||
MLAS_FORCEINLINE void
|
||||
ComputeDotProducts_BlkBitWidth4_CompFp32(
|
||||
|
|
@ -230,25 +255,12 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
|
|||
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
// Manual conversion to float takes place in two steps:
|
||||
// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f].
|
||||
// This target float range is convenient because the 4-bit source values can be placed directly into the
|
||||
// target float bits.
|
||||
// 2. Subtract the conversion offset of 16 from the float result.
|
||||
|
||||
// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values.
|
||||
constexpr uint16_t float_high_half_template = 0b0'10000011'0000000;
|
||||
// sign|exponent|partial mantissa
|
||||
// +|131: 2^4|~~~~ <- 4 bits go here
|
||||
|
||||
const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template);
|
||||
|
||||
float32x4_t acc[NCols]{};
|
||||
|
||||
const std::byte* QuantBData = QuantBDataColPtr;
|
||||
const float* QuantBScale = QuantBScaleColPtr;
|
||||
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
|
||||
// only used if HasZeroPoint == true
|
||||
// only used if HasZeroPoint is true
|
||||
|
||||
for (size_t k = 0; k < CountK; k += BlkLen) {
|
||||
const size_t k_blk_len = std::min(CountK - k, BlkLen);
|
||||
|
|
@ -258,8 +270,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
|
|||
[&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
|
||||
);
|
||||
|
||||
[[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16.
|
||||
// only used if HasZeroPoint == true
|
||||
[[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset.
|
||||
// only used if HasZeroPoint is true
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const std::byte zp_packed =
|
||||
|
|
@ -267,7 +279,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
|
|||
const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
|
||||
offset[i] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -304,8 +316,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
|
|||
|
||||
// combine 4 bits with float high half template
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v);
|
||||
bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v);
|
||||
bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v);
|
||||
bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v);
|
||||
});
|
||||
|
||||
// `SubBlkLen` floats of B
|
||||
|
|
@ -321,14 +333,14 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
|
|||
bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift));
|
||||
});
|
||||
|
||||
// subtract float conversion offset (16) and zero point
|
||||
// subtract float conversion offset and zero point
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
} else {
|
||||
const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f);
|
||||
const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f);
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
|
|
@ -454,7 +466,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE void
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompFp32(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
|
|
@ -497,7 +509,247 @@ SQ4BitGemmM1Kernel_CompFp32(
|
|||
}
|
||||
}
|
||||
|
||||
// Block dequantize a 16 x NCols section of B from column major source to row major destination.
|
||||
template <size_t NCols, bool HasZeroPoint>
|
||||
MLAS_FORCEINLINE void
|
||||
Q4BitBlkDequantB_16xNCols(
|
||||
const std::byte* QuantBDataPtr,
|
||||
size_t StrideQuantBData,
|
||||
const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns
|
||||
[[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns
|
||||
// only used if HasZeroPoint is true
|
||||
float* DstColPtr
|
||||
)
|
||||
{
|
||||
const uint8x8_t LowMask = vdup_n_u8(0x0F);
|
||||
|
||||
// load B column vectors
|
||||
uint8x8_t bv_packed[NCols];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_packed[i] = vld1_u8(
|
||||
reinterpret_cast<const uint8_t*>(QuantBDataPtr) + i * StrideQuantBData
|
||||
);
|
||||
});
|
||||
|
||||
uint8x8_t bv_u8[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
|
||||
bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
|
||||
});
|
||||
|
||||
// shift left 3 and widen to 16 bits
|
||||
uint16x8_t bv_u16[NCols][2];
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 3;
|
||||
bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift);
|
||||
bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift);
|
||||
});
|
||||
|
||||
// combine 4 bits with float high half template
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v);
|
||||
bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v);
|
||||
});
|
||||
|
||||
// `SubBlkLen` floats of B
|
||||
float32x4_t bv[NCols][4];
|
||||
|
||||
// shift left 16, widen to 32 bits, and reinterpret as float
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
constexpr int shift = 16;
|
||||
bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift));
|
||||
bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift));
|
||||
|
||||
bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift));
|
||||
bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift));
|
||||
});
|
||||
|
||||
// subtract float conversion offset and zero point
|
||||
if constexpr (HasZeroPoint) {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
} else {
|
||||
const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f);
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
|
||||
});
|
||||
}
|
||||
|
||||
// multiply by scale
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]);
|
||||
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); });
|
||||
});
|
||||
|
||||
// write, transposed, 16 x NCols values
|
||||
if constexpr (NCols == 4) {
|
||||
UnrolledLoop<4>([&](size_t j) {
|
||||
Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]);
|
||||
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]);
|
||||
vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]);
|
||||
});
|
||||
} else {
|
||||
UnrolledLoop<NCols>([&](size_t i) {
|
||||
UnrolledLoop<4>([&](size_t j) {
|
||||
DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0);
|
||||
DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1);
|
||||
DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2);
|
||||
DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasZeroPoint>
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
const std::byte* QuantBData,
|
||||
const float* QuantBScale,
|
||||
const std::byte* QuantBZeroPoint,
|
||||
size_t CountN,
|
||||
size_t CountK,
|
||||
size_t BlockStrideQuantB
|
||||
)
|
||||
{
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
|
||||
float* Dst = FpData;
|
||||
|
||||
const std::byte* QuantBDataCol = QuantBData;
|
||||
const float* QuantBScaleCol = QuantBScale;
|
||||
[[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true
|
||||
|
||||
const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
[[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true
|
||||
MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockStrideQuantB);
|
||||
|
||||
//
|
||||
// Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time.
|
||||
//
|
||||
|
||||
// scales of blocks from 16 adjacent columns
|
||||
float scale[16];
|
||||
// float conversion offsets (including zero point) of blocks from 16 adjacent columns
|
||||
[[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true
|
||||
|
||||
size_t n_cols_remaining = CountN;
|
||||
while (n_cols_remaining > 15) {
|
||||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) {
|
||||
for (size_t nn = 0; nn < 16; ++nn) {
|
||||
scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx];
|
||||
|
||||
if constexpr (HasZeroPoint) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2];
|
||||
const std::byte zp = ((k_blk_idx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[nn] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
for (size_t kk = 0; kk < kklen; kk += 16) {
|
||||
constexpr size_t NCols = 4;
|
||||
|
||||
const float* ScalePtr = &scale[0];
|
||||
const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr;
|
||||
|
||||
float* DstColPtr = Dst;
|
||||
|
||||
for (size_t nn = 0; nn < 16; nn += NCols) {
|
||||
const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8;
|
||||
|
||||
Q4BitBlkDequantB_16xNCols<NCols, HasZeroPoint>(
|
||||
QuantBDataPtr,
|
||||
StrideQuantBData,
|
||||
ScalePtr,
|
||||
OffsetPtr,
|
||||
DstColPtr
|
||||
);
|
||||
|
||||
ScalePtr += NCols;
|
||||
if constexpr (HasZeroPoint) {
|
||||
OffsetPtr += NCols;
|
||||
}
|
||||
DstColPtr += NCols;
|
||||
}
|
||||
|
||||
Dst += 16 * std::min(kklen - kk, size_t{16});
|
||||
}
|
||||
}
|
||||
|
||||
n_cols_remaining -= 16;
|
||||
|
||||
QuantBDataCol += 16 * StrideQuantBData;
|
||||
QuantBScaleCol += 16 * BlockStrideQuantB;
|
||||
if constexpr (HasZeroPoint) {
|
||||
QuantBZeroPointCol += 16 * StrideQuantBZeroPoint;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_cols_remaining > 0) {
|
||||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) {
|
||||
for (size_t nn = 0; nn < n_cols_remaining; ++nn) {
|
||||
scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx];
|
||||
|
||||
if constexpr (HasZeroPoint) {
|
||||
const std::byte zp_packed =
|
||||
QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2];
|
||||
const std::byte zp = ((k_blk_idx & 1) == 1)
|
||||
? (zp_packed >> 4)
|
||||
: (zp_packed & std::byte{0x0F});
|
||||
offset[nn] = fp32_conversion::offset + std::to_integer<uint8_t>(zp);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
for (size_t kk = 0; kk < kklen; kk += 16) {
|
||||
// zero out the 16x16 block in Dst first to ensure zero padding
|
||||
const float32x4_t zero_v = vdupq_n_f32(0.0f);
|
||||
UnrolledLoop<16 * 4>([&](size_t i) {
|
||||
vst1q_f32(Dst + 4 * i, zero_v);
|
||||
});
|
||||
|
||||
const float* ScalePtr = &scale[0];
|
||||
const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr;
|
||||
|
||||
float* DstColPtr = Dst;
|
||||
|
||||
for (size_t nn = 0; nn < n_cols_remaining; ++nn) {
|
||||
const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8;
|
||||
|
||||
Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>(
|
||||
QuantBDataPtr,
|
||||
StrideQuantBData,
|
||||
ScalePtr,
|
||||
OffsetPtr,
|
||||
DstColPtr
|
||||
);
|
||||
|
||||
ScalePtr += 1;
|
||||
if constexpr (HasZeroPoint) {
|
||||
OffsetPtr += 1;
|
||||
}
|
||||
DstColPtr += 1;
|
||||
}
|
||||
|
||||
Dst += 16 * std::min(kklen - kk, size_t{16});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
Q4BitBlkDequantBForSgemm_CompFp32(
|
||||
size_t BlkLen,
|
||||
float* FpData,
|
||||
|
|
@ -509,68 +761,29 @@ Q4BitBlkDequantBForSgemm_CompFp32(
|
|||
size_t BlockStrideQuantB
|
||||
)
|
||||
{
|
||||
auto impl0_reference = [&]() {
|
||||
constexpr size_t BlkBitWidth = 4;
|
||||
constexpr size_t SubBlkLen = 16;
|
||||
|
||||
float* Dst = FpData;
|
||||
|
||||
const std::byte* QuantBDataCol = QuantBData;
|
||||
const float* QuantBScaleCol = QuantBScale;
|
||||
const std::byte* QuantBZeroPointCol = QuantBZeroPoint;
|
||||
|
||||
for (size_t n = 0; n < CountN; n += 16) {
|
||||
const size_t nnlen = std::min(CountN - n, size_t{16});
|
||||
|
||||
for (size_t nn = 0; nn < nnlen; ++nn) {
|
||||
for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
|
||||
const size_t kklen = std::min(CountK - k, BlkLen);
|
||||
|
||||
const std::byte* b_data =
|
||||
QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
const float b_s = QuantBScaleCol[k_blk_idx];
|
||||
const uint8_t b_z =
|
||||
(QuantBZeroPointCol != nullptr)
|
||||
? ((k_blk_idx & 1) == 1)
|
||||
? std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] >> 4)
|
||||
: std::to_integer<uint8_t>(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F})
|
||||
: 8;
|
||||
|
||||
for (size_t kk = 0; kk < kklen; ++kk) {
|
||||
const size_t packed_idx = kk % SubBlkLen;
|
||||
|
||||
const bool is_low_half = packed_idx < (SubBlkLen / 2);
|
||||
const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2);
|
||||
const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2);
|
||||
|
||||
const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx];
|
||||
const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4);
|
||||
const float b_value = (std::to_integer<int8_t>(b_byte) - b_z) * b_s;
|
||||
|
||||
Dst[(k + kk) * 16 + nn] = b_value;
|
||||
}
|
||||
}
|
||||
|
||||
QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
|
||||
QuantBScaleCol += BlockStrideQuantB;
|
||||
if (QuantBZeroPointCol != nullptr) {
|
||||
QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockStrideQuantB);
|
||||
}
|
||||
}
|
||||
|
||||
// zero out any remaining columns
|
||||
|
||||
if (nnlen < 16) {
|
||||
for (size_t k = 0; k < CountK; ++k) {
|
||||
std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
Dst += CountK * 16;
|
||||
}
|
||||
};
|
||||
|
||||
impl0_reference();
|
||||
if (QuantBZeroPoint != nullptr) {
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl<true>(
|
||||
BlkLen,
|
||||
FpData,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockStrideQuantB
|
||||
);
|
||||
} else {
|
||||
Q4BitBlkDequantBForSgemm_CompFp32_Impl<false>(
|
||||
BlkLen,
|
||||
FpData,
|
||||
QuantBData,
|
||||
QuantBScale,
|
||||
QuantBZeroPoint,
|
||||
CountN,
|
||||
CountK,
|
||||
BlockStrideQuantB
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -666,7 +879,7 @@ QuantizeBlock(
|
|||
}
|
||||
}
|
||||
|
||||
void MLASCALL
|
||||
void
|
||||
QuantizeARow_CompInt8(
|
||||
size_t BlkLen,
|
||||
const float* A,
|
||||
|
|
@ -1175,7 +1388,6 @@ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
SQ4BitGemmM1Kernel_CompInt8(
|
||||
size_t BlkLen,
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@
|
|||
#include "core/util/thread_utils.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
|
||||
using onnxruntime::narrow;
|
||||
|
||||
template <size_t BlkBitWidth>
|
||||
void RunSQNBitGemmBenchmark(size_t BlkLen,
|
||||
size_t M, size_t N, size_t K,
|
||||
|
|
@ -44,8 +42,8 @@ void RunSQNBitGemmBenchmark(size_t BlkLen,
|
|||
onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(),
|
||||
tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP));
|
||||
|
||||
auto A = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
|
||||
auto B = RandomVectorUniform(static_cast<size_t>(K * N), -1.0f, 1.0f);
|
||||
const auto A = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
|
||||
const auto B = RandomVectorUniform(static_cast<size_t>(K * N), -1.0f, 1.0f);
|
||||
std::vector<float> C(static_cast<size_t>(M * N));
|
||||
|
||||
std::vector<uint8_t> QuantBData(QuantBDataSizeInBytes);
|
||||
|
|
@ -94,6 +92,8 @@ void RunSQNBitGemmBenchmark(size_t BlkLen,
|
|||
|
||||
template <size_t BlkBitWidth>
|
||||
void SQNBITGEMM(benchmark::State& state) {
|
||||
using onnxruntime::narrow;
|
||||
|
||||
const auto BlkLen = narrow<size_t>(state.range(0));
|
||||
const auto M = narrow<size_t>(state.range(1));
|
||||
const auto N = narrow<size_t>(state.range(2));
|
||||
|
|
@ -105,6 +105,22 @@ void SQNBITGEMM(benchmark::State& state) {
|
|||
RunSQNBitGemmBenchmark<BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, ComputeType, state);
|
||||
}
|
||||
|
||||
// Registers the argument names and the cartesian product of argument values
// for the SQNBITGEMM benchmark.
static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
  b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"})
      ->ArgsProduct({
          {16, 32, 64, 128, 256},                  // BlkLen
          {1, 1024, 2048},                         // M
          {4096, 11008},                           // N
          {4096, 11008},                           // K
          {1, 8},                                  // Threads
          {int64_t{false}, int64_t{true}},         // Symmetric
          {int64_t{CompFp32}, int64_t{CompInt8}},  // ComputeType
      });
}
|
||||
|
||||
BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
|
||||
|
||||
// This test gets benchmark arguments from environment variables.
|
||||
template <size_t BlkBitWidth>
|
||||
void SQNBITGEMM_ENV(benchmark::State& state) {
|
||||
|
|
@ -130,19 +146,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) {
|
|||
state.SetLabel(s.str());
|
||||
}
|
||||
|
||||
// Registers the argument names and the cartesian product of argument values
// for the SQNBITGEMM benchmark.
static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
  b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"})
      ->ArgsProduct({
          {16, 32, 64, 128, 256},                  // BlkLen
          {1, 1024, 2048},                         // M
          {4096, 11008},                           // N
          {4096, 11008},                           // K
          {1, 8},                                  // Threads
          {int64_t{false}, int64_t{true}},         // Symmetric
          {int64_t{CompFp32}, int64_t{CompInt8}},  // ComputeType
      });
}
|
||||
|
||||
BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
|
||||
BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime();
|
||||
|
|
|
|||
Loading…
Reference in a new issue