mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
use correct total length to fix static kv_cache performance
This commit is contained in:
parent
65008cbb73
commit
0a0a5ca7a8
1 changed files with 1 additions and 1 deletions
|
|
@ -439,7 +439,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
|
|||
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
|
||||
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
|
||||
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
|
||||
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
|
||||
const int total_sequence_length = parameters.total_sequence_length_;
|
||||
|
||||
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.sequence_length_, total_sequence_length});
|
||||
|
|
|
|||
Loading…
Reference in a new issue