From 0a0a5ca7a83111cec5c0c44215028d15a4aa3413 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 7 Feb 2025 08:21:06 -0800 Subject: [PATCH] use correct total length to fix static kv_cache performance --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 568e75b38a..33f67375b4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + const int total_sequence_length = parameters.total_sequence_length_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length});