Fix Packed MultiHead Attention (#17996)

### Description
Initialize previously uninitialized parameters that were causing the Op to
crash.



### Motivation and Context
Fixes the CUDA memory misalignment / illegal memory access error that occurred
when FlashAttention was used in Packed Multi-Head Attention.
This commit is contained in:
aciddelgado 2023-10-18 10:52:14 -07:00 committed by GitHub
parent 22947109f2
commit a2c6283274
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 44 deletions

View file

@ -18,81 +18,89 @@ constexpr int D_DIM = 2;
// Pointers, strides and head counts describing the Q, K and V inputs.
// NOTE(review): this listing looks like a rendered diff with old and new lines
// interleaved -- each declaration without an initializer appears to be the removed
// line and the initialized declaration that follows appears to be its replacement.
// Confirm against the commit before treating this text as compilable source.
struct Qkv_params {
// Integer type used for every stride field in this struct and its derived structs.
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;
void* __restrict__ q_ptr = nullptr;
void* __restrict__ k_ptr = nullptr;
void* __restrict__ v_ptr = nullptr;
// The stride between rows of the Q, K and V matrices.
// (Batch, row and head strides are kept separately for each of Q, K and V;
// the unit is presumably elements rather than bytes -- TODO confirm.)
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t q_batch_stride = 0;
index_t k_batch_stride = 0;
index_t v_batch_stride = 0;
index_t q_row_stride = 0;
index_t k_row_stride = 0;
index_t v_row_stride = 0;
index_t q_head_stride = 0;
index_t k_head_stride = 0;
index_t v_head_stride = 0;
// The number of heads.
int h, h_k;
int h = 0;
int h_k = 0;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
int h_h_k_ratio = 0; // precompute h / h_k,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Forward-pass parameter bundle: extends Qkv_params with the output buffers, softmax
// bookkeeping pointers, problem dimensions, scaling factors, sequence-offset arrays,
// the appended K_new/V_new inputs and a handful of mode flags.
// NOTE(review): this listing looks like a rendered diff with old and new lines
// interleaved -- each declaration without an initializer appears to be the removed
// line and the initialized declaration that follows appears to be its replacement.
// Confirm against the commit before treating this text as compilable source.
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
void* __restrict__ oaccum_ptr;
void* __restrict__ o_ptr = nullptr;
void* __restrict__ oaccum_ptr = nullptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
index_t o_batch_stride = 0;
index_t o_row_stride = 0;
index_t o_head_stride = 0;
// The pointer to the P matrix.
void* __restrict__ p_ptr;
void* __restrict__ p_ptr = nullptr;
// The pointer to the softmax sum.
void* __restrict__ softmax_lse_ptr;
void* __restrict__ softmax_lseaccum_ptr;
void* __restrict__ softmax_lse_ptr = nullptr;
void* __restrict__ softmax_lseaccum_ptr = nullptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b = 0;
int seqlen_q = 0;
int seqlen_k = 0;
int seqlen_knew = 0;
int d = 0;
int seqlen_q_rounded = 0;
int seqlen_k_rounded = 0;
int d_rounded = 0;
// The scaling factors for the kernel.
// NOTE(review): the initializers below are double literals (0.0) assigned to float
// members; 0.0f would match the declared type -- consider changing upstream.
float scale_softmax;
float scale_softmax_log2;
float scale_softmax = 0.0;
float scale_softmax_log2 = 0.0;
// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;
int* __restrict__ cu_seqlens_q = nullptr;
int* __restrict__ cu_seqlens_k = nullptr;
// Block mask pointer; its use is not visible in this listing.
int* __restrict__ blockmask;
int* __restrict__ blockmask = nullptr;
// The K_new and V_new matrices.
void* __restrict__ knew_ptr;
void* __restrict__ vnew_ptr;
void* __restrict__ knew_ptr = nullptr;
void* __restrict__ vnew_ptr = nullptr;
// The stride between rows of the Q, K and V matrices.
// (These six fields are the batch/row/head strides of K_new and V_new.)
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
index_t knew_batch_stride = 0;
index_t vnew_batch_stride = 0;
index_t knew_row_stride = 0;
index_t vnew_row_stride = 0;
index_t knew_head_stride = 0;
index_t vnew_head_stride = 0;
// Presumably selects bfloat16 instead of half-precision inputs -- TODO confirm against the kernels.
bool is_bf16 = false;
// Causal-masking flag; the initialized declaration defaults it to false.
bool is_causal;
bool is_causal = false;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
// The initialized declaration defaults this flag to true (the cumulative interpretation).
bool is_seqlens_k_cumulative;
int num_splits; // For split-KV version
bool is_seqlens_k_cumulative = true;
int num_splits = 0; // For split-KV version
// Pointer to the CUDA device properties; nullptr in the initialized declaration until a caller assigns it.
const cudaDeviceProp* dprops;
const cudaDeviceProp* dprops = nullptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View file

@ -215,7 +215,6 @@ Status mha_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
@ -230,7 +229,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
kv_bsnh);
params.dprops = &dprops;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
params.knew_batch_stride = 0;
@ -276,7 +275,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
@ -290,6 +288,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal);
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
params.oaccum_ptr = nullptr;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
run_mha_fwd(params, stream);
return Status::OK();
}
@ -336,7 +340,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
@ -351,6 +354,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
past_bsnh);
params.dprops = &dprops;
if (k != nullptr && v != nullptr) {
params.seqlen_knew = seqlen_k_new;