small update

Arthur Zucker 2025-01-15 17:47:36 +01:00
parent 3fc1e02e3c
commit 960e176910
3 changed files with 37 additions and 25 deletions


@@ -256,14 +256,16 @@ class ContinuousBatch:
self.cache_index[k] += [new_block * cache.block_size]
else:
self.cache_index[k] += [self.cache_index[k][-1] + 1]
position_ids += [self.position_ids[k][-1]]
next_full_cache_position += self.cache_index[k]
# how to efficiently select the next block? -> we probably just take the next longest sequence for now!
while len(self.next_ids) < self.batch_size and len(free_block_index) > 0:
while len(self.cumulative_seqlens_q) <= self.batch_size and len(free_block_index) > 0:
next_sequence = self.input_tokens.pop(0)
sample_length = len(next_sequence)
if len(free_block_index) < (sample_length // cache.block_size) + 1:
# we have to make sure there are enough free blocks
self.input_tokens.insert(0, next_sequence)
self.input_tokens.append(next_sequence)  # send it to the back of the queue and retry later
print("not enough memory to process this one, skipping it")
continue
new_ids += next_sequence
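
A minimal sketch of this admission check, assuming a FIFO input_tokens queue and a fixed block_size; all concrete values below are made up:

def blocks_needed(sample_length: int, block_size: int) -> int:
    # one block per full chunk of block_size tokens, plus one for the remainder
    return sample_length // block_size + 1

# hypothetical state: two waiting prompts, three free KV-cache blocks of 4 slots each
input_tokens = [[1, 2, 3, 4, 5], [6, 7]]
free_block_index = [0, 1, 2]
block_size = 4

next_sequence = input_tokens.pop(0)
if len(free_block_index) < blocks_needed(len(next_sequence), block_size):
    input_tokens.append(next_sequence)  # not enough blocks: requeue at the back
else:
    allocated = [free_block_index.pop(0) for _ in range(blocks_needed(len(next_sequence), block_size))]
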
@@ -300,6 +302,7 @@ class ContinuousBatch:
self.max_seqlens_k = max_seqlens_k
self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k)
self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q)
self.position_ids = position_ids
assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong"
position_ids = torch.tensor([position_ids])
new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1) # new sequence placed at the end
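
For intuition, a toy sketch (values invented) of the packed layout these tensors describe, with two sequences of lengths 3 and 2:

import torch

seq_a, seq_b = [10, 11, 12], [20, 21]
new_ids = torch.tensor([seq_a + seq_b])          # one batch row of 5 packed tokens
position_ids = torch.tensor([[0, 1, 2, 0, 1]])   # positions restart for each sequence
cumulative_seqlens_q = torch.tensor([3, 5])      # running end offsets of the sequences

# sanity check mirroring the assert above: one position per packed token
assert new_ids.shape[-1] == position_ids.shape[-1] == cumulative_seqlens_q[-1]
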
@@ -307,19 +310,22 @@ class ContinuousBatch:
# position_ids[0, self.cumulative_seqlens_k - 1] -> the last index?
return new_ids, position_ids, torch.tensor(next_full_cache_position)
def evict_finished_sequences(self, generated_ids):
# 1. We need to del/index select only the generated_ids that are != eos token
keep_mask = generated_ids != self.eos_token_id
def update(self, generated_ids):
for k in self.next_ids:
self.generated_ids.append(k) # add the token to the full sequence
# given the cumulative_seqlens_q we need to index the cache index as we only keep the last sequences
evict_mask = generated_ids == self.eos_token_id
del self.cache_index[evict_mask]
self.finished_sequences.append(generated_ids[evict_mask])
# we need to also update cumulative_seqlens_q and k
evict_mask = generated_ids == self.eos_token_id
keep_mask = ~evict_mask
self.next_ids = generated_ids[keep_mask].clone()
self.cumulative_seqlens_k = self.cumulative_seqlens_k[keep_mask]
self.cumulative_seqlens_q -= self.cumulative_seqlens_q[::-1] # we shift to get the previous value
self.next_ids = generated_ids.index_select(keep_mask)
return self.next_ids
if evict_mask.any():
evict_indices = torch.where(evict_mask)[0]  # indices of the finished sequences
evict_indices = evict_indices.tolist()
# delete from the back so earlier deletions do not shift the remaining indices
for idx in reversed(evict_indices):
del self.cache_index[idx]
del self.cumulative_seqlens_k[idx]
del self.cumulative_seqlens_q[idx]
def __len__(self):
return len(self.input_tokens)
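
To make the eviction step concrete, a small sketch assuming eos_token_id = 2 and three active sequences; all values are hypothetical:

import torch

eos_token_id = 2
generated_ids = torch.tensor([5, 2, 7])      # the middle sequence just emitted EOS

evict_mask = generated_ids == eos_token_id   # tensor([False, True, False])
keep_mask = ~evict_mask
next_ids = generated_ids[keep_mask].clone()  # tensor([5, 7]) keeps the live sequences

# torch.where on a boolean tensor returns the indices of the True entries,
# which is what the per-sequence bookkeeping (cache_index, seqlens) is deleted by
evict_indices = torch.where(evict_mask)[0].tolist()  # [1]
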
@@ -371,11 +377,11 @@ class ContinuousMixin:
}
out = self.model.forward(
continous_batch, position_ids=position_ids, **kwargs
).last_hidden_states
logits = self.model.lm_head(out[cache_index.cumulative_seqlens_q-1])
).last_hidden_state
logits = self.lm_head(out[:, current_batch.cumulative_seqlens_q - 1, :])
# we don't sample for now :)
generated_ids = torch.argmax(logits, dim=-1)
current_batch.evict_finished_sequences(generated_ids)
current_batch.update(generated_ids[0])
if len(current_batch.finished_sequences) > paged_attention_cache.batch_size:
yield current_batch.finished_sequences
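
A hedged sketch of this greedy decode step; hidden_size, vocab_size, and the standalone lm_head below are stand-ins, not the model's real modules:

import torch

hidden_size, vocab_size, total_tokens = 8, 32, 5
hidden = torch.randn(1, total_tokens, hidden_size)        # packed last_hidden_state
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
cumulative_seqlens_q = torch.tensor([3, 5])               # two packed sequences

# only the last hidden state of each packed sequence goes through the LM head
logits = lm_head(hidden[:, cumulative_seqlens_q - 1, :])  # (1, 2, vocab_size)
generated_ids = torch.argmax(logits, dim=-1)              # (1, 2), one token per sequence
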


@@ -30,7 +30,13 @@ def sdpa_attention_forward(
**kwargs,
) -> Tuple[torch.Tensor, None]:
key, value = cache.update(key, value, module.layer_idx, cumulative_seqlens_k, **kwargs)
# start from an all-masked additive bias over the packed tokens
attention_mask_ = torch.full(
[1, 1, query.shape[2], query.shape[2] + 1], torch.finfo(query.dtype).min, device=query.device, dtype=query.dtype
)
# open a block on the diagonal for each packed sequence, so tokens only attend within their own sequence
attention_mask_[..., 0 : cumulative_seqlens_q[0], 0 : cumulative_seqlens_q[0]] = 0
for i in range(1, len(cumulative_seqlens_q)):
attention_mask_[..., cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i], cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i]] = 0
attention_mask = (attention_mask_ != 0).to(query.dtype) * torch.finfo(query.dtype).min  # rebuild the additive mask from the block layout
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
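
For reference, a self-contained sketch of the block-diagonal mask built above, for two packed sequences of lengths 3 and 2 (square shape, ignoring the extra key/value column and any cache offset):

import torch
import torch.nn.functional as F

cumulative_seqlens_q = [3, 5]
total = cumulative_seqlens_q[-1]
mask = torch.full((1, 1, total, total), torch.finfo(torch.float32).min)
mask[..., : cumulative_seqlens_q[0], : cumulative_seqlens_q[0]] = 0
for i in range(1, len(cumulative_seqlens_q)):
    start, end = cumulative_seqlens_q[i - 1], cumulative_seqlens_q[i]
    mask[..., start:end, start:end] = 0  # each sequence only attends to itself

query = key = value = torch.randn(1, 1, total, 4)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)  # (1, 1, 5, 4)
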


@@ -637,14 +637,14 @@ class LlamaModel(LlamaPreTrainedModel):
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
# if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
# if AttentionMaskConverter._ignore_causal_mask_sdpa(
# attention_mask,
# inputs_embeds=input_tensor,
# past_key_values_length=past_seen_tokens,
# is_training=self.training,
# ):
# return None
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
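
A quick illustration of why the sdpa mask-skipping shortcut is disabled here: if this path returned None, sdpa would fall back to a plain causal mask over the whole packed row, letting a later sequence attend to an earlier one (toy example, values invented):

import torch

# packed row: sequence A at positions 0-2, sequence B at positions 3-4
total, boundary = 5, 3
causal = torch.tril(torch.ones(total, total, dtype=torch.bool))

# under a plain causal mask, B's first token (position 3) can see all of A
print(causal[3, :boundary])  # tensor([True, True, True]) -> cross-sequence leak
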