diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 8e6b202ca4..7530823895 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -103,17 +103,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (packed_qkv) { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); - } else if (sequence_length > 1) { + } else { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K)); ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); - } else { - Tensor::InitOrtValue(std::move(const_cast(*query)), Q); - Tensor::InitOrtValue(std::move(const_cast(*key)), K); - Tensor::InitOrtValue(std::move(const_cast(*value)), V); } if (do_rotary_) {