mirror of https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Feat]: Add Multithreading support for kleidiai groupwise GEMM kernels (#144074)
The KleidiAI groupwise GEMM kernel was not 2D-blocked. This change adds support for 2D blocking of the GEMM kernel so the workload is split efficiently across multiple threads, speeding up the kernel. Performance improvement: 7B model prefill throughput goes from 145 t/s to 175 t/s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144074
Approved by: https://github.com/digantdesai
This commit is contained in:
parent
5a2e8fce9d
commit
b7f95df65b
1 changed file with 157 additions and 115 deletions
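To make the 2D blocking concrete before reading the diff: the kernel now splits the M x N output across a threads_y x threads_x grid, and each worker decodes its thread id into a (row, column) tile. Below is a minimal, self-contained sketch of that index arithmetic; the dimensions and grid shape are made-up example values, not the PR's code (which performs the same thread_id -> (x, y) -> offsets mapping inside at::parallel_for):

// Sketch of the 2D blocking idea (illustrative only). An M x N output is
// split across a threads_y x threads_x grid; each thread id maps to one
// (m, n) tile, mirroring the dispatch arithmetic in the hunks below.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t m = 8, n = 64;                // example output dimensions
  const size_t threads_x = 4, threads_y = 2; // hypothetical thread grid
  const size_t num_threads = threads_x * threads_y;
  const size_t m_per_thread = m / threads_y; // assumes exact divisibility
  const size_t n_per_thread = n / threads_x;

  for (size_t thread_id = 0; thread_id < num_threads; ++thread_id) {
    // Same decomposition as the dispatch loop: grid row, then grid column.
    const size_t y = thread_id / threads_x;
    const size_t x = thread_id % threads_x;
    const size_t m_start = y * m_per_thread;
    const size_t n_start = x * n_per_thread;
    std::printf(
        "thread %zu -> rows [%zu, %zu), cols [%zu, %zu)\n",
        thread_id, m_start, m_start + m_per_thread, n_start,
        n_start + n_per_thread);
  }
  return 0;
}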
@@ -146,42 +146,66 @@ static void matmul_channelwise(
 }
 
 static void matmul_groupwise(
-    const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
-    const size_t m,
-    const size_t num_n_per_thread,
-    const size_t n_start,
-    const size_t k,
-    const size_t bl,
-    const size_t dst_stride,
-    const void* lhs_ptr,
-    uint8_t* rhs_packed,
-    uint8_t* dst_data) {
-  const size_t rhs_packed_offset =
-      ukernel.get_rhs_packed_offset(n_start, k, bl);
-  const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
+    kai_matmul_ukernel_f32_qa8dxp_qs4c32p& kernel_packet,
+    size_t m_increment,
+    size_t m_start,
+    size_t m_per_thread,
+    size_t n_start,
+    size_t n_per_thread,
+    size_t n,
+    size_t k,
+    size_t bl,
+    size_t mr,
+    size_t nr,
+    size_t kr,
+    size_t sr,
+    size_t dst_stride,
+    size_t lhs_stride,
+    uint8_t* lhs_native_mtx_f32,
+    uint8_t* lhs_packed_mtx_qa8dx,
+    uint8_t* rhs_packed_mtx_qs4cx,
+    uint8_t* dst_act_mtx_f32) {
+  for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
+    const float* src_ptr =
+        (const float*)(lhs_native_mtx_f32 +
+                       kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
+                           m_start + m0, lhs_stride));
+    size_t lhs_packed_offset =
+        kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(0, k, mr, kr, sr);
+    void* lhs_packed_ptr = (void*)(lhs_packed_mtx_qa8dx + lhs_packed_offset);
+    const void* rhs_packed_ptr =
+        (const void*)((const char*)rhs_packed_mtx_qs4cx +
+                      kernel_packet.ukernel.get_rhs_packed_offset(
+                          n_start, k, bl));
+    float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
+                              kernel_packet.ukernel.get_dst_offset(
+                                  m_start + m0, n_start, dst_stride));
 
-  const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
-  float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
+    // Quantize and pack the Input
+    kernel_packet.kai_run_lhs_quant_pack(
+        m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
 
-  // Run Matmul on Int4 packed weights and Quantized Packed Input
-  ukernel.run_matmul(
-      m,
-      num_n_per_thread,
-      k,
-      bl,
-      lhs_ptr,
-      rhs_ptr,
-      dst_ptr,
-      dst_stride,
-      sizeof(float),
-      -FLT_MAX,
-      FLT_MAX);
+    // Run Matmul on Int4 packed weights and Quantized Packed Input
+    kernel_packet.ukernel.run_matmul(
+        m_increment,
+        n_per_thread,
+        k,
+        bl,
+        lhs_packed_ptr,
+        rhs_packed_ptr,
+        dst_ptr,
+        dst_stride,
+        sizeof(float),
+        -FLT_MAX,
+        FLT_MAX);
+  }
 }
 
 struct ThreadDivision {
   int64_t num_threads_x;
   int64_t num_threads_y;
-  bool use_gemm; // True if GEMM is selected, false if GEMV is used
+  bool can_gemm; // True if GEMM is selected, false if GEMV is used. For Certain
+                 // Configurations, GEMV Kernel might be used even if M>1
 };
 
 inline static unsigned int round_down_to_power_of_2(unsigned int n) {
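The body of round_down_to_power_of_2 lies outside this hunk. For reference, a standard bit-twiddling routine with the same contract might look like the sketch below; it is a plausible stand-in under that assumption, not the file's actual code:

// Hypothetical sketch: round an unsigned int down to the nearest power of
// two by smearing the highest set bit rightward, then keeping only that bit.
#include <cassert>

inline static unsigned int round_down_to_power_of_2_sketch(unsigned int n) {
  if (n == 0) {
    return 0; // no power of two at or below zero
  }
  n |= n >> 1; // propagate the top set bit into every lower position
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  return n - (n >> 1); // keep just the highest set bit
}

int main() {
  assert(round_down_to_power_of_2_sketch(1) == 1);
  assert(round_down_to_power_of_2_sketch(6) == 4);
  assert(round_down_to_power_of_2_sketch(32) == 32);
  assert(round_down_to_power_of_2_sketch(33) == 32);
  return 0;
}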
@@ -262,71 +286,89 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
     const int64_t n,
     const int64_t k,
     const int64_t bl) {
-  kai_kernel_id id = kai_kernel_id::
+  // Kernel IDs for GEMM and GEMV
+  kai_kernel_id gemm_id =
+      kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
+  kai_kernel_id gemv_id = kai_kernel_id::
       matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
-  if (cpuinfo_has_arm_i8mm() && m > 1) {
-    id =
-        kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
+
+  // Get the total number of threads available and choose GEMM or GEMV steps
+  const int64_t total_threads = at::get_num_threads();
+  auto gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id);
+  if (cpuinfo_has_arm_i8mm()) {
+    gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemm_id);
   }
-  auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
+  auto gemv_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id);
 
-  const auto& ukernel = kernel_packet.ukernel;
-
-  const size_t mr = ukernel.get_mr();
-  const size_t kr = ukernel.get_kr();
-  const size_t sr = ukernel.get_sr();
-  const size_t n_step = ukernel.get_n_step();
-  int64_t total_threads = at::get_num_threads();
-  int64_t num_threads_x = 1;
-  adjust_max_threads(total_threads);
-  // Split threads 1D only for now
-  if (n % n_step == 0) {
-    for (; total_threads > 0; total_threads -= 2) {
-      if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
-        num_threads_x = total_threads;
-        break;
-      }
-    }
-  }
-
-  const size_t num_n_per_thread = n / num_threads_x;
+  // Retrieve m_step and n_step values from GEMM and GEMV kernels
+  const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
+  const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
+  const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
+  const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
+  // Determine threading and kernel type
+  ThreadDivision division = get_thread_division(
+      total_threads,
+      m,
+      n,
+      k,
+      gemm_m_step,
+      gemm_n_step,
+      gemv_m_step,
+      gemv_n_step);
+  // Select appropriate kernel packet based on the chosen kernel type
+  auto& kernel_packet =
+      division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet;
+
+  // Thread blocking parameters
+  const size_t mr = kernel_packet.ukernel.get_mr();
+  const size_t nr = kernel_packet.ukernel.get_nr();
+  const size_t kr = kernel_packet.ukernel.get_kr();
+  const size_t sr = kernel_packet.ukernel.get_sr();
+  const size_t m_increment = kernel_packet.ukernel.get_m_step();
+  const size_t n_per_thread = n / division.num_threads_x;
+  const size_t m_per_thread = m / division.num_threads_y;
+  const int64_t num_threads = division.num_threads_y * division.num_threads_x;
   const size_t dst_stride = n * sizeof(float);
-  float* lhs = reinterpret_cast<float*>(input.data_ptr());
-  uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
+  const size_t lhs_stride = k * sizeof(float);
+
+  const size_t lhs_packed_size =
+      kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
+
   uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
-  const size_t lhs_packed_size =
-      kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
-  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
+  uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
+  uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
+  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
+  uint8_t* lhs_packed_base = lhs_packed.get();
 
-  // Quantize and pack the Input
-  kernel_packet.kai_run_lhs_quant_pack(
-      m,
-      k,
-      mr,
-      kr,
-      sr,
-      0,
-      (const float*)lhs,
-      k * sizeof(float),
-      (void*)lhs_packed.get());
-
-  at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
-    for (const auto x : c10::irange(begin, end)) {
-      matmul_groupwise(
-          std::ref(ukernel),
-          m,
-          num_n_per_thread,
-          x * num_n_per_thread,
-          k,
-          bl,
-          dst_stride,
-          lhs_packed.get(),
-          rhs_packed_mtx_qs4cx,
-          dst_act_mtx_f32);
-    }
-  });
+  at::parallel_for(
+      0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) {
+        for (const auto thread_id : c10::irange(begin, end)) {
+          size_t y = thread_id / division.num_threads_x;
+          size_t x = thread_id % division.num_threads_x;
+          uint8_t* lhs_packed_ptr = lhs_packed_base +
+              (x + y * division.num_threads_x) * lhs_packed_size;
+          matmul_groupwise(
+              std::ref(kernel_packet),
+              m_increment,
+              /*m_start=*/y * m_per_thread,
+              m_per_thread,
+              x * n_per_thread,
+              n_per_thread,
+              n,
+              k,
+              bl,
+              mr,
+              nr,
+              kr,
+              sr,
+              dst_stride,
+              lhs_stride,
+              lhs_native_mtx_f32,
+              lhs_packed_ptr,
+              rhs_packed_mtx_qs4cx,
+              dst_act_mtx_f32);
+        }
+      });
 }
 
 static void kai_quant_pack_lhs_int4_mm_channelwise(
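get_thread_division itself is not shown in this diff. As a hedged sketch of what such a policy could look like (the real heuristic may differ; all logic here is assumed, not taken from the source), the following fills the ThreadDivision struct from the first hunk: decide whether the GEMM kernel is usable, then factor the thread budget into a grid whose per-thread tiles stay whole multiples of the kernel step sizes:

// Hypothetical sketch of a get_thread_division() policy; illustration only.
#include <algorithm>
#include <cstdint>

struct ThreadDivision {
  int64_t num_threads_x;
  int64_t num_threads_y;
  bool can_gemm;
};

static ThreadDivision get_thread_division_sketch(
    int64_t total_threads,
    int64_t m,
    int64_t n,
    int64_t k, // unused in this sketch; the real policy may weigh it
    int64_t gemm_m_step,
    int64_t gemm_n_step,
    int64_t gemv_m_step,
    int64_t gemv_n_step) {
  (void)k;
  ThreadDivision division{1, 1, false};
  // GEMM is worthwhile only when the rows fill whole GEMM m-tiles.
  division.can_gemm = (m > 1) && (m % gemm_m_step == 0);
  const int64_t m_step = division.can_gemm ? gemm_m_step : gemv_m_step;
  const int64_t n_step = division.can_gemm ? gemm_n_step : gemv_n_step;
  // Spend threads on row blocks first; each per-thread tile must stay a
  // whole multiple of m_step.
  for (int64_t ty = std::min(total_threads, m / m_step); ty >= 2; --ty) {
    if (m % ty == 0 && (m / ty) % m_step == 0) {
      division.num_threads_y = ty;
      break;
    }
  }
  // Spend the remaining budget on column blocks, keeping n_step-aligned tiles.
  for (int64_t tx = total_threads / division.num_threads_y; tx >= 2; --tx) {
    if (n % tx == 0 && (n / tx) % n_step == 0) {
      division.num_threads_x = tx;
      break;
    }
  }
  return division;
}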
@@ -367,7 +409,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
       gemv_n_step);
   // Select appropriate kernel packet based on the chosen kernel type
   auto& kernel_packet =
-      division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
+      division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet;
 
   // Thread blocking parameters
   const size_t mr = kernel_packet.ukernel.get_mr();
@@ -380,7 +422,6 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
   const int64_t num_threads = division.num_threads_y * division.num_threads_x;
   const size_t dst_stride = n * sizeof(float);
   const size_t lhs_stride = k * sizeof(float);
-
   const size_t lhs_packed_size =
       kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
 
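Note the scratch-buffer pattern shared by both paths: the groupwise code above now mirrors what the channelwise code below already did. Each thread quantizes and packs its own m_increment-row slice of the LHS, so the packed-LHS buffer is sized per thread and replicated num_threads times. A small sketch of that layout follows; the 4096-byte size is a made-up stand-in for the value kai_get_lhs_packed_size() would return:

// Illustrative sketch of the per-thread packed-LHS scratch layout.
#include <cstddef>
#include <cstdint>
#include <memory>

int main() {
  const size_t num_threads = 8;
  const size_t lhs_packed_size = 4096; // per-thread slice for m_increment rows
  // One contiguous allocation, one disjoint slice per thread, so workers can
  // quantize-and-pack concurrently without sharing or synchronizing buffers.
  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
  uint8_t* lhs_packed_base = lhs_packed.get();
  for (size_t thread_id = 0; thread_id < num_threads; ++thread_id) {
    uint8_t* slice = lhs_packed_base + thread_id * lhs_packed_size;
    (void)slice; // each worker packs its rows of the LHS into its own slice
  }
  return 0;
}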
@@ -390,33 +431,34 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
   auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
   uint8_t* lhs_packed_base = lhs_packed.get();
 
-  at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
-    for (const auto i : c10::irange(begin, end)) {
-      size_t y = i / division.num_threads_x;
-      size_t x = i % division.num_threads_x;
-      uint8_t* lhs_packed_ptr =
-          lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
-      matmul_channelwise(
-          std::ref(kernel_packet),
-          m_increment,
-          y * m_per_thread,
-          m_per_thread,
-          x * n_per_thread,
-          n_per_thread,
-          n,
-          k,
-          mr,
-          nr,
-          kr,
-          sr,
-          dst_stride,
-          lhs_stride,
-          lhs_native_mtx_f32,
-          lhs_packed_ptr,
-          rhs_packed_mtx_qs4cx,
-          dst_act_mtx_f32);
-    }
-  });
+  at::parallel_for(
+      0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) {
+        for (const auto thread_id : c10::irange(begin, end)) {
+          size_t y = thread_id / division.num_threads_x;
+          size_t x = thread_id % division.num_threads_x;
+          uint8_t* lhs_packed_ptr = lhs_packed_base +
+              (x + y * division.num_threads_x) * lhs_packed_size;
+          matmul_channelwise(
+              std::ref(kernel_packet),
+              m_increment,
+              /*m_start=*/y * m_per_thread,
+              m_per_thread,
+              x * n_per_thread,
+              n_per_thread,
+              n,
+              k,
+              mr,
+              nr,
+              kr,
+              sr,
+              dst_stride,
+              lhs_stride,
+              lhs_native_mtx_f32,
+              lhs_packed_ptr,
+              rhs_packed_mtx_qs4cx,
+              dst_act_mtx_f32);
+        }
+      });
 }
 
 void kai_quant_pack_lhs_int4_mm(