[Feat]: Add Multithreading support for kleidiai groupwise GEMM kernels (#144074)

The KleidiAI groupwise GEMM kernel was not 2D-blocked. This change adds support for 2D blocking of the GEMM kernel to split the workload efficiently and speed up the kernel across multiple threads.

Performance improvements:
7B model pre-fill throughput improves from 145 t/s to 175 t/s
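The 2D split assigns each worker one output tile: thread (y, x) in a num_threads_y x num_threads_x grid computes rows [y*m_per_thread, (y+1)*m_per_thread) and columns [x*n_per_thread, (x+1)*n_per_thread) of the output. A minimal, self-contained C++ sketch of that mapping (block_2d and run_tile are illustrative names, not part of the kernel API):

#include <cstddef>

void block_2d(size_t m, size_t n, size_t num_threads_y, size_t num_threads_x) {
  const size_t m_per_thread = m / num_threads_y; // rows per thread block
  const size_t n_per_thread = n / num_threads_x; // columns per thread block
  for (size_t thread_id = 0; thread_id < num_threads_y * num_threads_x;
       ++thread_id) {
    const size_t y = thread_id / num_threads_x; // block-row index
    const size_t x = thread_id % num_threads_x; // block-column index
    (void)y;
    (void)x;
    // Each worker would run the micro-kernel on its own output tile:
    // run_tile(/*m_start=*/y * m_per_thread, m_per_thread,
    //          /*n_start=*/x * n_per_thread, n_per_thread);
  }
}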

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144074
Approved by: https://github.com/digantdesai
Author: Nikhil Gupta
Date: 2025-01-13 20:32:20 +00:00
Committed by: PyTorch MergeBot
Parent: 5a2e8fce9d
Commit: b7f95df65b


@@ -146,42 +146,66 @@ static void matmul_channelwise(
 }
 static void matmul_groupwise(
-    const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
-    const size_t m,
-    const size_t num_n_per_thread,
-    const size_t n_start,
-    const size_t k,
-    const size_t bl,
-    const size_t dst_stride,
-    const void* lhs_ptr,
-    uint8_t* rhs_packed,
-    uint8_t* dst_data) {
-  const size_t rhs_packed_offset =
-      ukernel.get_rhs_packed_offset(n_start, k, bl);
-  const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
+    kai_matmul_ukernel_f32_qa8dxp_qs4c32p& kernel_packet,
+    size_t m_increment,
+    size_t m_start,
+    size_t m_per_thread,
+    size_t n_start,
+    size_t n_per_thread,
+    size_t n,
+    size_t k,
+    size_t bl,
+    size_t mr,
+    size_t nr,
+    size_t kr,
+    size_t sr,
+    size_t dst_stride,
+    size_t lhs_stride,
+    uint8_t* lhs_native_mtx_f32,
+    uint8_t* lhs_packed_mtx_qa8dx,
+    uint8_t* rhs_packed_mtx_qs4cx,
+    uint8_t* dst_act_mtx_f32) {
+  for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
+    const float* src_ptr =
+        (const float*)(lhs_native_mtx_f32 +
+                       kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
+                           m_start + m0, lhs_stride));
+    size_t lhs_packed_offset =
+        kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(0, k, mr, kr, sr);
+    void* lhs_packed_ptr = (void*)(lhs_packed_mtx_qa8dx + lhs_packed_offset);
+    const void* rhs_packed_ptr =
+        (const void*)((const char*)rhs_packed_mtx_qs4cx +
+                      kernel_packet.ukernel.get_rhs_packed_offset(
+                          n_start, k, bl));
+    float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
+                              kernel_packet.ukernel.get_dst_offset(
+                                  m_start + m0, n_start, dst_stride));
-  const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
-  float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
+    // Quantize and pack the Input
+    kernel_packet.kai_run_lhs_quant_pack(
+        m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
-  // Run Matmul on Int4 packed weights and Quantized Packed Input
-  ukernel.run_matmul(
-      m,
-      num_n_per_thread,
-      k,
-      bl,
-      lhs_ptr,
-      rhs_ptr,
-      dst_ptr,
-      dst_stride,
-      sizeof(float),
-      -FLT_MAX,
-      FLT_MAX);
+    // Run Matmul on Int4 packed weights and Quantized Packed Input
+    kernel_packet.ukernel.run_matmul(
+        m_increment,
+        n_per_thread,
+        k,
+        bl,
+        lhs_packed_ptr,
+        rhs_packed_ptr,
+        dst_ptr,
+        dst_stride,
+        sizeof(float),
+        -FLT_MAX,
+        FLT_MAX);
+  }
 }
 struct ThreadDivision {
   int64_t num_threads_x;
   int64_t num_threads_y;
-  bool use_gemm; // True if GEMM is selected, false if GEMV is used
+  bool can_gemm; // True if GEMM is selected, false if GEMV is used. For Certain
+                 // Configurations, GEMV Kernel might be used even if M>1
 };
 inline static unsigned int round_down_to_power_of_2(unsigned int n) {
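Note the restructuring above: the LHS quantize-and-pack step now runs inside each worker on a single m_increment row block, so every thread needs its own packed-LHS scratch slot; the (x + y * num_threads_x) * lhs_packed_size offset used by the callers in the next hunk is just thread_id * lhs_packed_size. A simplified sketch of that layout, mirroring the caller code below (alloc_lhs_scratch and lhs_slot are illustrative helper names):

#include <cstddef>
#include <cstdint>
#include <memory>

// One private packed-LHS slot per worker, each sized for a single
// m_increment row block; workers never share a slot, so the quantize-and-pack
// step inside matmul_groupwise needs no cross-thread locking.
std::unique_ptr<uint8_t[]> alloc_lhs_scratch(
    size_t lhs_packed_size_per_block, size_t num_threads) {
  return std::make_unique<uint8_t[]>(lhs_packed_size_per_block * num_threads);
}

uint8_t* lhs_slot(
    uint8_t* base, size_t thread_id, size_t lhs_packed_size_per_block) {
  return base + thread_id * lhs_packed_size_per_block; // worker's own slot
}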
@@ -262,71 +286,89 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
     const int64_t n,
     const int64_t k,
     const int64_t bl) {
-  kai_kernel_id id = kai_kernel_id::
+  // Kernel IDs for GEMM and GEMV
+  kai_kernel_id gemm_id =
+      kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
+  kai_kernel_id gemv_id = kai_kernel_id::
       matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
-  if (cpuinfo_has_arm_i8mm() && m > 1) {
-    id =
-        kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
+  // Get the total number of threads available and choose GEMM or GEMV steps
+  const int64_t total_threads = at::get_num_threads();
+  auto gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id);
+  if (cpuinfo_has_arm_i8mm()) {
+    gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemm_id);
   }
-  auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
+  auto gemv_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id);
-  const auto& ukernel = kernel_packet.ukernel;
-  const size_t mr = ukernel.get_mr();
-  const size_t kr = ukernel.get_kr();
-  const size_t sr = ukernel.get_sr();
-  const size_t n_step = ukernel.get_n_step();
-  int64_t total_threads = at::get_num_threads();
-  int64_t num_threads_x = 1;
-  adjust_max_threads(total_threads);
-  // Split threads 1D only for now
-  if (n % n_step == 0) {
-    for (; total_threads > 0; total_threads -= 2) {
-      if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
-        num_threads_x = total_threads;
-        break;
-      }
-    }
-  }
-  const size_t num_n_per_thread = n / num_threads_x;
+  // Retrieve m_step and n_step values from GEMM and GEMV kernels
+  const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
+  const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
+  const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
+  const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
+  // Determine threading and kernel type
+  ThreadDivision division = get_thread_division(
+      total_threads,
+      m,
+      n,
+      k,
+      gemm_m_step,
+      gemm_n_step,
+      gemv_m_step,
+      gemv_n_step);
+  // Select appropriate kernel packet based on the chosen kernel type
+  auto& kernel_packet =
+      division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet;
+  // Thread blocking parameters
+  const size_t mr = kernel_packet.ukernel.get_mr();
+  const size_t nr = kernel_packet.ukernel.get_nr();
+  const size_t kr = kernel_packet.ukernel.get_kr();
+  const size_t sr = kernel_packet.ukernel.get_sr();
+  const size_t m_increment = kernel_packet.ukernel.get_m_step();
+  const size_t n_per_thread = n / division.num_threads_x;
+  const size_t m_per_thread = m / division.num_threads_y;
+  const int64_t num_threads = division.num_threads_y * division.num_threads_x;
   const size_t dst_stride = n * sizeof(float);
-  float* lhs = reinterpret_cast<float*>(input.data_ptr());
-  uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
+  const size_t lhs_stride = k * sizeof(float);
+  const size_t lhs_packed_size =
+      kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
   uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
-  const size_t lhs_packed_size =
-      kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
-  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
+  uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
+  uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
+  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
+  uint8_t* lhs_packed_base = lhs_packed.get();
-  // Quantize and pack the Input
-  kernel_packet.kai_run_lhs_quant_pack(
-      m,
-      k,
-      mr,
-      kr,
-      sr,
-      0,
-      (const float*)lhs,
-      k * sizeof(float),
-      (void*)lhs_packed.get());
-  at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
-    for (const auto x : c10::irange(begin, end)) {
-      matmul_groupwise(
-          std::ref(ukernel),
-          m,
-          num_n_per_thread,
-          x * num_n_per_thread,
-          k,
-          bl,
-          dst_stride,
-          lhs_packed.get(),
-          rhs_packed_mtx_qs4cx,
-          dst_act_mtx_f32);
-    }
-  });
+  at::parallel_for(
+      0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) {
+        for (const auto thread_id : c10::irange(begin, end)) {
+          size_t y = thread_id / division.num_threads_x;
+          size_t x = thread_id % division.num_threads_x;
+          uint8_t* lhs_packed_ptr = lhs_packed_base +
+              (x + y * division.num_threads_x) * lhs_packed_size;
+          matmul_groupwise(
+              std::ref(kernel_packet),
+              m_increment,
+              /*m_start=*/y * m_per_thread,
+              m_per_thread,
+              x * n_per_thread,
+              n_per_thread,
+              n,
+              k,
+              bl,
+              mr,
+              nr,
+              kr,
+              sr,
+              dst_stride,
+              lhs_stride,
+              lhs_native_mtx_f32,
+              lhs_packed_ptr,
+              rhs_packed_mtx_qs4cx,
+              dst_act_mtx_f32);
+        }
+      });
 }
 static void kai_quant_pack_lhs_int4_mm_channelwise(
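get_thread_division(), defined elsewhere in this file, picks the num_threads_y x num_threads_x grid and sets can_gemm. Its real heuristic is not shown in this diff; as a hedged sketch only, a policy consistent with the constraints visible above (per-thread N width divisible by the kernel's n_step, M split into whole m_step blocks) might look like:

#include <algorithm>
#include <cstdint>

struct ThreadDivision {
  int64_t num_threads_x;
  int64_t num_threads_y;
  bool can_gemm;
};

// Illustrative policy, not the actual get_thread_division() implementation.
ThreadDivision divide_threads_sketch(
    int64_t total_threads,
    int64_t m,
    int64_t n,
    int64_t gemm_m_step,
    int64_t gemm_n_step) {
  ThreadDivision d{1, 1, m > 1};
  // Widest column split whose per-thread width stays a multiple of n_step.
  for (int64_t t = total_threads; t > 0; --t) {
    if (n % t == 0 && (n / t) % gemm_n_step == 0) {
      d.num_threads_x = t;
      break;
    }
  }
  // Spend the remaining threads on row blocks of whole m_step chunks.
  d.num_threads_y = std::max<int64_t>(
      1, std::min<int64_t>(total_threads / d.num_threads_x, m / gemm_m_step));
  return d;
}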
@@ -367,7 +409,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
       gemv_n_step);
   // Select appropriate kernel packet based on the chosen kernel type
   auto& kernel_packet =
-      division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
+      division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet;
   // Thread blocking parameters
   const size_t mr = kernel_packet.ukernel.get_mr();
@@ -380,7 +422,6 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
   const int64_t num_threads = division.num_threads_y * division.num_threads_x;
   const size_t dst_stride = n * sizeof(float);
   const size_t lhs_stride = k * sizeof(float);
   const size_t lhs_packed_size =
       kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
@@ -390,33 +431,34 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
   auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
   uint8_t* lhs_packed_base = lhs_packed.get();
-  at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
-    for (const auto i : c10::irange(begin, end)) {
-      size_t y = i / division.num_threads_x;
-      size_t x = i % division.num_threads_x;
-      uint8_t* lhs_packed_ptr =
-          lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
-      matmul_channelwise(
-          std::ref(kernel_packet),
-          m_increment,
-          y * m_per_thread,
-          m_per_thread,
-          x * n_per_thread,
-          n_per_thread,
-          n,
-          k,
-          mr,
-          nr,
-          kr,
-          sr,
-          dst_stride,
-          lhs_stride,
-          lhs_native_mtx_f32,
-          lhs_packed_ptr,
-          rhs_packed_mtx_qs4cx,
-          dst_act_mtx_f32);
-    }
-  });
+  at::parallel_for(
+      0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) {
+        for (const auto thread_id : c10::irange(begin, end)) {
+          size_t y = thread_id / division.num_threads_x;
+          size_t x = thread_id % division.num_threads_x;
+          uint8_t* lhs_packed_ptr = lhs_packed_base +
+              (x + y * division.num_threads_x) * lhs_packed_size;
+          matmul_channelwise(
+              std::ref(kernel_packet),
+              m_increment,
+              /*m_start=*/y * m_per_thread,
+              m_per_thread,
+              x * n_per_thread,
+              n_per_thread,
+              n,
+              k,
+              mr,
+              nr,
+              kr,
+              sr,
+              dst_stride,
+              lhs_stride,
+              lhs_native_mtx_f32,
+              lhs_packed_ptr,
+              rhs_packed_mtx_qs4cx,
+              dst_act_mtx_f32);
+        }
+      });
 }
 void kai_quant_pack_lhs_int4_mm(