use correct total length to fix static kv_cache performance

This commit is contained in:
Guenther Schmuelling 2025-02-07 08:21:06 -08:00
parent 65008cbb73
commit 0a0a5ca7a8

View file

@@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
-  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
+  const int total_sequence_length = parameters.total_sequence_length_;
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, total_sequence_length});