From 05b4ad2e574367abead2b733d0cf42acadc877b3 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Mon, 6 May 2024 16:05:00 -0700 Subject: [PATCH] fix bug: input q/k/v should not be modified by operator (#20555) ### Description Operator should not modify input tensors because they are managed by framework and may be reused by other nodes. ### Motivation and Context --- onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 8e6b202ca4..7530823895 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -103,17 +103,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (packed_qkv) { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); - } else if (sequence_length > 1) { + } else { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K)); ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); - } else { - Tensor::InitOrtValue(std::move(const_cast(*query)), Q); - Tensor::InitOrtValue(std::move(const_cast(*key)), K); - Tensor::InitOrtValue(std::move(const_cast(*value)), V); } if (do_rotary_) {