fix bug: input q/k/v should not be modified by operator (#20555)

### Description
<!-- Describe your changes. -->
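Operators should not modify their input tensors, because those tensors are
managed by the framework and may be reused by other nodes. Previously, when
QKV was not packed and `sequence_length` was not greater than 1, the kernel
moved the input query/key/value tensors into Q/K/V through `const_cast`; that
branch is removed, so all non-packed inputs now go through
`MaybeTransposeToBNSH`.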


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Yufeng Li 2024-05-06 16:05:00 -07:00 committed by GitHub
parent c86476a636
commit 05b4ad2e57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -103,17 +103,13 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
if (packed_qkv) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
} else if (sequence_length > 1) {
} else {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
} else {
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
}
if (do_rotary_) {