mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix Packed MultiHead Attention (#17996)
### Description Initialize previously uninitialized parameters that were causing the Op to crash. ### Motivation and Context Solves a CUDA Memory Misalignment / Illegal Memory Access error when FlashAttention was used in Packed Multi-Head Attention.
This commit is contained in:
parent
22947109f2
commit
a2c6283274
2 changed files with 56 additions and 44 deletions
|
|
@ -18,81 +18,89 @@ constexpr int D_DIM = 2;
|
|||
struct Qkv_params {
|
||||
using index_t = uint32_t;
|
||||
// The QKV matrices.
|
||||
void* __restrict__ q_ptr;
|
||||
void* __restrict__ k_ptr;
|
||||
void* __restrict__ v_ptr;
|
||||
void* __restrict__ q_ptr = nullptr;
|
||||
void* __restrict__ k_ptr = nullptr;
|
||||
void* __restrict__ v_ptr = nullptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
index_t q_batch_stride = 0;
|
||||
index_t k_batch_stride = 0;
|
||||
index_t v_batch_stride = 0;
|
||||
index_t q_row_stride = 0;
|
||||
index_t k_row_stride = 0;
|
||||
index_t v_row_stride = 0;
|
||||
index_t q_head_stride = 0;
|
||||
index_t k_head_stride = 0;
|
||||
index_t v_head_stride = 0;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
int h = 0;
|
||||
int h_k = 0;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
int h_h_k_ratio = 0; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
// The O matrix (output).
|
||||
void* __restrict__ o_ptr;
|
||||
void* __restrict__ oaccum_ptr;
|
||||
void* __restrict__ o_ptr = nullptr;
|
||||
void* __restrict__ oaccum_ptr = nullptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
index_t o_batch_stride = 0;
|
||||
index_t o_row_stride = 0;
|
||||
index_t o_head_stride = 0;
|
||||
|
||||
// The pointer to the P matrix.
|
||||
void* __restrict__ p_ptr;
|
||||
void* __restrict__ p_ptr = nullptr;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void* __restrict__ softmax_lse_ptr;
|
||||
void* __restrict__ softmax_lseaccum_ptr;
|
||||
void* __restrict__ softmax_lse_ptr = nullptr;
|
||||
void* __restrict__ softmax_lseaccum_ptr = nullptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
int b = 0;
|
||||
int seqlen_q = 0;
|
||||
int seqlen_k = 0;
|
||||
int seqlen_knew = 0;
|
||||
int d = 0;
|
||||
int seqlen_q_rounded = 0;
|
||||
int seqlen_k_rounded = 0;
|
||||
int d_rounded = 0;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
float scale_softmax = 0.0;
|
||||
float scale_softmax_log2 = 0.0;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int* __restrict__ cu_seqlens_q;
|
||||
int* __restrict__ cu_seqlens_k;
|
||||
int* __restrict__ cu_seqlens_q = nullptr;
|
||||
int* __restrict__ cu_seqlens_k = nullptr;
|
||||
|
||||
int* __restrict__ blockmask;
|
||||
int* __restrict__ blockmask = nullptr;
|
||||
|
||||
// The K_new and V_new matrices.
|
||||
void* __restrict__ knew_ptr;
|
||||
void* __restrict__ vnew_ptr;
|
||||
void* __restrict__ knew_ptr = nullptr;
|
||||
void* __restrict__ vnew_ptr = nullptr;
|
||||
|
||||
// The strides of the K_new and V_new matrices.
|
||||
index_t knew_batch_stride;
|
||||
index_t vnew_batch_stride;
|
||||
index_t knew_row_stride;
|
||||
index_t vnew_row_stride;
|
||||
index_t knew_head_stride;
|
||||
index_t vnew_head_stride;
|
||||
index_t knew_batch_stride = 0;
|
||||
index_t vnew_batch_stride = 0;
|
||||
index_t knew_row_stride = 0;
|
||||
index_t vnew_row_stride = 0;
|
||||
index_t knew_head_stride = 0;
|
||||
index_t vnew_head_stride = 0;
|
||||
|
||||
bool is_bf16 = false;
|
||||
bool is_causal;
|
||||
bool is_causal = false;
|
||||
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
bool is_seqlens_k_cumulative;
|
||||
int num_splits; // For split-KV version
|
||||
bool is_seqlens_k_cumulative = true;
|
||||
int num_splits = 0; // For split-KV version
|
||||
|
||||
const cudaDeviceProp* dprops;
|
||||
const cudaDeviceProp* dprops = nullptr;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
|||
|
|
@ -215,7 +215,6 @@ Status mha_fwd(const cudaDeviceProp& dprops,
|
|||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
Flash_fwd_params params;
|
||||
params.dprops = &dprops;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
|
|
@ -230,7 +229,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
|
|||
softmax_scale,
|
||||
is_causal,
|
||||
kv_bsnh);
|
||||
|
||||
params.dprops = &dprops;
|
||||
params.knew_ptr = nullptr;
|
||||
params.vnew_ptr = nullptr;
|
||||
params.knew_batch_stride = 0;
|
||||
|
|
@ -276,7 +275,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
|
|||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
Flash_fwd_params params;
|
||||
params.dprops = &dprops;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
|
|
@ -290,6 +288,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
|
|||
softmax_lse,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
params.dprops = &dprops;
|
||||
params.num_splits = 0;
|
||||
params.softmax_lseaccum_ptr = nullptr;
|
||||
params.oaccum_ptr = nullptr;
|
||||
params.knew_ptr = nullptr;
|
||||
params.vnew_ptr = nullptr;
|
||||
run_mha_fwd(params, stream);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -336,7 +340,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
Flash_fwd_params params;
|
||||
params.dprops = &dprops;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
|
|
@ -351,6 +354,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
|
|||
softmax_scale,
|
||||
is_causal,
|
||||
past_bsnh);
|
||||
params.dprops = &dprops;
|
||||
|
||||
if (k != nullptr && v != nullptr) {
|
||||
params.seqlen_knew = seqlen_k_new;
|
||||
|
|
|
|||
Loading…
Reference in a new issue