mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
fix bug: input q/k/v should not be modified by operator (#20555)
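### Description <!-- Describe your changes. --> Operators should not modify their input tensors, because those tensors are managed by the framework and may be reused by other nodes. This change removes the branch that moved the `query`, `key` and `value` inputs into Q/K/V through `const_cast` and `std::move`; non-packed inputs now always go through `MaybeTransposeToBNSH`. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->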
This commit is contained in:
parent
c86476a636
commit
05b4ad2e57
1 changed file with 1 addition and 5 deletions
|
|
@@ -103,17 +103,13 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
|||
if (packed_qkv) {
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
|
||||
} else if (sequence_length > 1) {
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
|
||||
} else {
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
|
||||
}
|
||||
|
||||
if (do_rotary_) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue