mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Fix Memory Issue GQA CPU Rotary (#22290)
### Description

In GQA there was a memory issue, which was best described by @edgchen1 [here](https://github.com/microsoft/onnxruntime/issues/22252#issuecomment-2384559255):

> here's the problematic code:
>
> d9de054eb5/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc (L149-L157)
>
> annotated:
>
> ```c++
> if (packed_qkv) {
>   // Q is an OrtValue declared in the enclosing scope.
>   OrtValue RotaryQKV;
>   Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
>   // Save pointer to Q's data in q_input.
>   q_input = Q.Get<Tensor>().Data<T>();
>   k_input = q_input + num_heads_ * sequence_length * head_size;
>   q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
>   k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
>   // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value).
>   // Now, q_input is pointing to freed memory.
>   Q = RotaryQKV;
> }
> ```
>
> later on, when we use `q_input`, there is a read access violation.
>
> d9de054eb5/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc (L170-L172)
>
> this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue.
>
> though I debugged into the first branch, this appears to be a problem in both branches:
>
> d9de054eb5/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc (L149-L168)

### Motivation and Context

Fixes a crucial bug. The issue was found here: https://github.com/microsoft/onnxruntime/issues/22252
This commit is contained in:
parent
efb8703a25
commit
cc0193cd42
1 changed file with 8 additions and 11 deletions
|
|
@ -106,6 +106,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
|||
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
|
||||
}
|
||||
|
||||
OrtValue RotaryQKV;
|
||||
OrtValue RotaryQ;
|
||||
OrtValue RotaryK;
|
||||
T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
|
||||
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
|
||||
if (do_rotary_) {
|
||||
// Initialize rotary parameters
|
||||
rotary_embedding_helper::RotaryParameters rotary_params = {};
|
||||
|
|
@ -128,7 +133,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
|||
if (parameters.is_first_prompt) {
|
||||
pos_ids[0] = static_cast<int64_t>(0);
|
||||
} else {
|
||||
// Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1.
|
||||
// Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1.
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1;
|
||||
const int past_seqlen = total_seqlen - sequence_length;
|
||||
|
|
@ -144,27 +149,19 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
|||
// Initialize separate buffers for rotary embeddings
|
||||
const T* q_input;
|
||||
const T* k_input;
|
||||
T* q_rotary;
|
||||
T* k_rotary;
|
||||
if (packed_qkv) {
|
||||
OrtValue RotaryQKV;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
|
||||
q_input = Q.Get<Tensor>().Data<T>();
|
||||
k_input = q_input + num_heads_ * sequence_length * head_size;
|
||||
q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
|
||||
k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
|
||||
Q = RotaryQKV;
|
||||
} else {
|
||||
OrtValue RotaryQ;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ);
|
||||
OrtValue RotaryK;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK);
|
||||
q_input = Q.Get<Tensor>().Data<T>();
|
||||
k_input = K.Get<Tensor>().Data<T>();
|
||||
q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
|
||||
k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
|
||||
Q = RotaryQ;
|
||||
K = RotaryK;
|
||||
}
|
||||
// Run rotary embedding for Q and K
|
||||
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
|
||||
|
|
@ -196,8 +193,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
|
||||
packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v,
|
||||
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
|
||||
past_key, past_value, output, present_k, present_v,
|
||||
seqlens_k, parameters, allocator, context);
|
||||
}
|
||||
} // namespace contrib
|
||||
|
|
|
|||
Loading…
Reference in a new issue