From ccaa4d1db26046bde60b7962083d9230c980647f Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 17 Apr 2024 17:50:26 -0700 Subject: [PATCH] [MLAS][AArch64] SQNBitGemm M>1 CompFp32 kernel optimization (#20319) Add ARM NEON intrinsics implementation for `Q4BitBlkDequantBForSgemm_CompFp32`. --- onnxruntime/core/mlas/lib/sqnbitgemm.h | 4 + .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 388 ++++++++++++++---- .../test/mlas/bench/bench_sqnbitgemm.cpp | 39 +- 3 files changed, 324 insertions(+), 107 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 3992bc3e45..318a51e1c8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -168,6 +168,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * * @param BlkLen Number of values in a block. * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. * @param QuantBData Supplies the quantized B matrix block data. * @param QuantBScale Supplies the quantized B matrix block scale values. * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 9d7b0ae06e..2fc24b358b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -141,8 +141,8 @@ UnrolledLoop(IterationFn&& f) UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); } -MLAS_FORCEINLINE float32x4_t -FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +MLAS_FORCEINLINE void +Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) { // aN: aN_0 aN_1 aN_2 aN_3 @@ -159,7 +159,12 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); // a0_3 a1_3 a2_3 a3_3 a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); +} +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + Transpose4x4(a0, a1, a2, a3); return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); } @@ -205,6 +210,26 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) namespace { +namespace fp32_conversion +{ + +// Manual conversion to float takes place in two steps: +// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. +// This target float range is convenient because the 4-bit source values can be placed directly into the +// target float bits. +// 2. Subtract the conversion offset of 16 from the float result. + +// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. +constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; +// sign|exponent|partial mantissa +// +|131: 2^4|~~~~ <- 4 bits go here + +const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + +constexpr float offset = 16.0f; + +} // namespace fp32_conversion + template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompFp32( @@ -230,25 +255,12 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const uint8x8_t LowMask = vdup_n_u8(0x0F); - // Manual conversion to float takes place in two steps: - // 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. - // This target float range is convenient because the 4-bit source values can be placed directly into the - // target float bits. - // 2. Subtract the conversion offset of 16 from the float result. - - // The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. - constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; - // sign|exponent|partial mantissa - // +|131: 2^4|~~~~ <- 4 bits go here - - const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); - float32x4_t acc[NCols]{}; const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint == true + // only used if HasZeroPoint is true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -258,8 +270,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. - // only used if HasZeroPoint == true + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. + // only used if HasZeroPoint is true if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = @@ -267,7 +279,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & std::byte{0x0F}); - offset[i] = 16.0f + std::to_integer(zp); + offset[i] = fp32_conversion::offset + std::to_integer(zp); }); } @@ -304,8 +316,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( // combine 4 bits with float high half template UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v); + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); }); // `SubBlkLen` floats of B @@ -321,14 +333,14 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); }); - // subtract float conversion offset (16) and zero point + // subtract float conversion offset and zero point if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const float32x4_t offset_v = vdupq_n_f32(offset[i]); UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); }); } else { - const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); UnrolledLoop([&](size_t i) { UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); }); @@ -454,7 +466,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( } } -MLAS_FORCEINLINE void +void SQ4BitGemmM1Kernel_CompFp32( size_t BlkLen, const float* A, @@ -497,7 +509,247 @@ SQ4BitGemmM1Kernel_CompFp32( } } +// Block dequantize a 16 x NCols section of B from column major source to row major destination. +template MLAS_FORCEINLINE void +Q4BitBlkDequantB_16xNCols( + const std::byte* QuantBDataPtr, + size_t StrideQuantBData, + const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns + [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns + // only used if HasZeroPoint is true + float* DstColPtr +) +{ + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // write, transposed, 16 x NCols values + if constexpr (NCols == 4) { + UnrolledLoop<4>([&](size_t j) { + Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); + + vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); + vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); + vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); + vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); + }); + } else { + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { + DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); + DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); + DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); + DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); + }); + }); + } +} + +template +void +Q4BitBlkDequantBForSgemm_CompFp32_Impl( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + constexpr size_t BlkBitWidth = 4; + + float* Dst = FpData; + + const std::byte* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true + + const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true + MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); + + // + // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. + // + + // scales of blocks from 16 adjacent columns + float scale[16]; + // float conversion offsets (including zero point) of blocks from 16 adjacent columns + [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true + + size_t n_cols_remaining = CountN; + while (n_cols_remaining > 15) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < 16; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + constexpr size_t NCols = 4; + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < 16; nn += NCols) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += NCols; + if constexpr (HasZeroPoint) { + OffsetPtr += NCols; + } + DstColPtr += NCols; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + + n_cols_remaining -= 16; + + QuantBDataCol += 16 * StrideQuantBData; + QuantBScaleCol += 16 * BlockStrideQuantB; + if constexpr (HasZeroPoint) { + QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; + } + } + + if (n_cols_remaining > 0) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + // zero out the 16x16 block in Dst first to ensure zero padding + const float32x4_t zero_v = vdupq_n_f32(0.0f); + UnrolledLoop<16 * 4>([&](size_t i) { + vst1q_f32(Dst + 4 * i, zero_v); + }); + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += 1; + if constexpr (HasZeroPoint) { + OffsetPtr += 1; + } + DstColPtr += 1; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + } +} + +void Q4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, @@ -509,68 +761,29 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockStrideQuantB ) { - auto impl0_reference = [&]() { - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; - - float* Dst = FpData; - - const std::byte* QuantBDataCol = QuantBData; - const float* QuantBScaleCol = QuantBScale; - const std::byte* QuantBZeroPointCol = QuantBZeroPoint; - - for (size_t n = 0; n < CountN; n += 16) { - const size_t nnlen = std::min(CountN - n, size_t{16}); - - for (size_t nn = 0; nn < nnlen; ++nn) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { - const size_t kklen = std::min(CountK - k, BlkLen); - - const std::byte* b_data = - QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const float b_s = QuantBScaleCol[k_blk_idx]; - const uint8_t b_z = - (QuantBZeroPointCol != nullptr) - ? ((k_blk_idx & 1) == 1) - ? std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] >> 4) - : std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F}) - : 8; - - for (size_t kk = 0; kk < kklen; ++kk) { - const size_t packed_idx = kk % SubBlkLen; - - const bool is_low_half = packed_idx < (SubBlkLen / 2); - const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); - const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); - - const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; - const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); - const float b_value = (std::to_integer(b_byte) - b_z) * b_s; - - Dst[(k + kk) * 16 + nn] = b_value; - } - } - - QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScaleCol += BlockStrideQuantB; - if (QuantBZeroPointCol != nullptr) { - QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); - } - } - - // zero out any remaining columns - - if (nnlen < 16) { - for (size_t k = 0; k < CountK; ++k) { - std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f); - } - } - - Dst += CountK * 16; - } - }; - - impl0_reference(); + if (QuantBZeroPoint != nullptr) { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockStrideQuantB + ); + } } // @@ -666,7 +879,7 @@ QuantizeBlock( } } -void MLASCALL +void QuantizeARow_CompInt8( size_t BlkLen, const float* A, @@ -1175,7 +1388,6 @@ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( } } -MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompInt8( size_t BlkLen, diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 04f5947e13..903c5a4985 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -16,8 +16,6 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -using onnxruntime::narrow; - template void RunSQNBitGemmBenchmark(size_t BlkLen, size_t M, size_t N, size_t K, @@ -44,8 +42,8 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); - auto B = RandomVectorUniform(static_cast(K * N), -1.0f, 1.0f); + const auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); + const auto B = RandomVectorUniform(static_cast(K * N), -1.0f, 1.0f); std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); @@ -94,6 +92,8 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, template void SQNBITGEMM(benchmark::State& state) { + using onnxruntime::narrow; + const auto BlkLen = narrow(state.range(0)); const auto M = narrow(state.range(1)); const auto N = narrow(state.range(2)); @@ -105,6 +105,22 @@ void SQNBITGEMM(benchmark::State& state) { RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, ComputeType, state); } +static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); + + b->ArgsProduct({ + {16, 32, 64, 128, 256}, // BlkLen + {1, 1024, 2048}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + }); +} + +BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); + // This test gets benchmark arguments from environment variables. template void SQNBITGEMM_ENV(benchmark::State& state) { @@ -130,19 +146,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { - b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); - - b->ArgsProduct({ - {16, 32, 64, 128, 256}, // BlkLen - {1, 1024, 2048}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType - }); -} - -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime();