mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
small update
This commit is contained in:
parent 3fc1e02e3c
commit 960e176910
3 changed files with 37 additions and 25 deletions
@ -256,14 +256,16 @@ class ContinuousBatch:
                self.cache_index[k] += [new_block * cache.block_size]
            else:
                self.cache_index[k] += [self.cache_index[k][-1] + 1]
            position_ids += [self.position_ids[k][-1]]
            next_full_cache_position += self.cache_index[k]

        # how to efficiently select the next block? -> we probably just take the next longest sequence for now!
        while len(self.next_ids) < self.batch_size and len(free_block_index) > 0:
        while len(self.cumulative_seqlens_q) <= self.batch_size and len(free_block_index) > 0:
            next_sequence = self.input_tokens.pop(0)
            sample_length = len(next_sequence)
            if len(free_block_index) < (sample_length // cache.block_size) + 1:
                # we have to make sure there are enough free blocks
                self.input_tokens.insert(0, next_sequence)
                self.input_tokens.insert(-1, next_sequence)
                print("not enough memory to process this one, skipping")
                continue
            new_ids += next_sequence
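The check above reserves sample_length // cache.block_size + 1 cache blocks before a prompt is admitted to the batch. A minimal standalone sketch of that check, assuming a fixed block size and a list of free block indices (blocks_needed and can_schedule are illustrative helpers, not names from the class):

# Sketch: decide whether a prompt of sample_length tokens fits in the free blocks.
def blocks_needed(sample_length: int, block_size: int) -> int:
    # one block per full block_size chunk, plus one for the remainder
    return sample_length // block_size + 1

def can_schedule(sample_length: int, free_block_index: list[int], block_size: int) -> bool:
    return len(free_block_index) >= blocks_needed(sample_length, block_size)

# e.g. a 35-token prompt with block_size=16 needs 35 // 16 + 1 = 3 blocks.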
@ -300,6 +302,7 @@ class ContinuousBatch:
        self.max_seqlens_k = max_seqlens_k
        self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k)
        self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q)
        self.position_ids = position_ids
        assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong"
        position_ids = torch.tensor([position_ids])
        new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1)  # new sequence placed at the end
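For context, cumulative_seqlens_q and cumulative_seqlens_k are running offsets over the per-sequence lengths of the packed batch, so sequence i occupies the slice cumulative_seqlens_q[i - 1]:cumulative_seqlens_q[i] of the flattened token dimension. A hedged sketch of how such offsets can be built, with made-up lengths:

import torch

# Sketch: cumulative offsets for a packed batch of three sequences.
seq_lengths = [5, 3, 7]                                        # per-sequence lengths (illustrative)
cumulative_seqlens_q = torch.cumsum(torch.tensor(seq_lengths), dim=0)
# tensor([ 5,  8, 15]): sequence 0 spans 0:5, sequence 1 spans 5:8, sequence 2 spans 8:15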
@ -307,19 +310,22 @@ class ContinuousBatch:
        # position_ids[0, self.cumulative_seqlens_k - 1] -> the last index?
        return new_ids, position_ids, torch.tensor(next_full_cache_position)

    def evict_finished_sequences(self, generated_ids):
        # 1. We need to del/index select only the generated_ids that are != eos token
        keep_mask = generated_ids != self.eos_token_id

    def update(self, generated_ids):
        for k in self.next_ids:
            self.generated_ids.append(k)  # add the token to the full sequence

        # given the cumulative_seqlens_q we need to index the cache index as we only keep the last sequences
        evict_mask = generated_ids == self.eos_token_id
        del self.cache_index[evict_mask]
        self.finished_sequences.append(generated_ids[evict_mask])
        # we need to also update cumulative_seqlens_q and k
        evict_mask = generated_ids == self.eos_token_id
        keep_mask = ~evict_mask
        self.next_ids = generated_ids[keep_mask].clone()
        self.cumulative_seqlens_k = self.cumulative_seqlens_k[keep_mask]

        self.cumulative_seqlens_q -= self.cumulative_seqlens_q[::-1]  # we shift to get the previous value
        self.next_ids = generated_ids.index_select(keep_mask)
        return self.next_ids
        if evict_mask.sum(-1) > 0:
            evict_mask = torch.where(evict_mask is not False)[0]  # delete the cache positions for these tokens
            evict_mask = evict_mask.tolist()

            del self.cache_index[evict_mask]
            del self.cumulative_seqlens_k[evict_mask]
            del self.cumulative_seqlens_q[evict_mask]

    def __len__(self):
        return len(self.input_tokens)
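The eviction logic above is still being iterated on (the hunk carries both a boolean-mask variant and an index_select variant). A minimal sketch of the boolean-mask idea it is converging on, with an assumed eos_token_id and toy token values:

import torch

# Sketch: split one step of generated tokens into kept and finished sequences.
eos_token_id = 2
generated_ids = torch.tensor([17, 2, 42, 2, 99])
evict_mask = generated_ids == eos_token_id
keep_mask = ~evict_mask
next_ids = generated_ids[keep_mask]       # tensor([17, 42, 99]) keeps generating
finished = generated_ids[evict_mask]      # tensor([2, 2]) would move to finished_sequences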
@ -371,11 +377,11 @@ class ContinuousMixin:
        }
        out = self.model.forward(
            continous_batch, position_ids=position_ids, **kwargs
        ).last_hidden_states
        logits = self.model.lm_head(out[cache_index.cumulative_seqlens_q - 1])
        ).last_hidden_state
        logits = self.lm_head(out[:, current_batch.cumulative_seqlens_q - 1, :])
        # we don't sample for now :)
        generated_ids = torch.argmax(logits, dim=-1)
        current_batch.evict_finished_sequences(generated_ids)
        current_batch.update(generated_ids[0])
        if len(current_batch.finished_sequences) > paged_attention_cache.batch_size:
            yield current_batch.finished_sequences
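In the packed layout, the next-token logits for sequence i come from the hidden state at flattened position cumulative_seqlens_q[i] - 1, which is what the indexing into out exploits. A hedged sketch of that gather plus greedy decoding, with toy shapes and a toy lm_head (none of these values come from the diff):

import torch

# Sketch: gather the last hidden state of each packed sequence, then greedy-decode.
hidden = torch.randn(1, 15, 64)                        # [batch=1, packed_length, hidden_size]
lm_head = torch.nn.Linear(64, 100, bias=False)         # toy vocabulary of 100 tokens
cumulative_seqlens_q = torch.tensor([5, 8, 15])        # three packed sequences of lengths 5, 3, 7
last_hidden = hidden[:, cumulative_seqlens_q - 1, :]   # [1, 3, 64], one position per sequence
logits = lm_head(last_hidden)                          # [1, 3, 100]
generated_ids = torch.argmax(logits, dim=-1)           # [1, 3], one greedy token per sequence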
@ -30,7 +30,13 @@ def sdpa_attention_forward(
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    key, value = cache.update(key, value, module.layer_idx, cumulative_seqlens_k, **kwargs)

    attention_mask_ = torch.full(
        [1, 1, query.shape[2], query.shape[2] + 1], torch.finfo(query.dtype).min, device=query.device, dtype=query.dtype
    )
    attention_mask_[..., 0 : cumulative_seqlens_q[0], 0 : cumulative_seqlens_q[0]] = 0
    for i in range(1, len(cumulative_seqlens_q)):
        attention_mask_[..., cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i], cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i]] = 0
    attention_mask = (attention_mask != 0 * attention_mask_).to(query.dtype) * torch.finfo(query.dtype).min
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)
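The loop above builds a block-diagonal additive mask: query/key pairs from different packed sequences get torch.finfo(dtype).min so they cannot attend to each other. A small self-contained sketch of the same construction, with illustrative sequence lengths:

import torch

# Sketch: block-diagonal additive attention mask for a packed batch.
cumulative_seqlens_q = [2, 5, 6]                  # three sequences of lengths 2, 3, 1
total = cumulative_seqlens_q[-1]
mask = torch.full((1, 1, total, total), torch.finfo(torch.float32).min)
start = 0
for end in cumulative_seqlens_q:
    mask[..., start:end, start:end] = 0           # tokens may only attend within their own sequence
    start = end
# mask[0, 0] is 0 inside each diagonal block and a large negative value everywhere else.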
@ -637,14 +637,14 @@ class LlamaModel(LlamaPreTrainedModel):
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None
        # if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(
        #         attention_mask,
        #         inputs_embeds=input_tensor,
        #         past_key_values_length=past_seen_tokens,
        #         is_training=self.training,
        #     ):
        #         return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
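Commenting out the fast path above presumably forces a mask to always be materialized: with a packed batch, returning None would let SDPA fall back to a single causal mask over the whole flattened sequence, allowing tokens to attend across sequence boundaries. A tiny illustration of the difference, with made-up lengths:

import torch

# Sketch: plain causal mask over a packed length of 5 vs. the causal-within-sequence mask needed here.
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))           # what SDPA's is_causal=True would use
blockdiag = torch.block_diag(torch.ones(2, 2), torch.ones(3, 3)).bool()
wanted = causal & blockdiag                                       # causal, but only within each packed sequence
# causal[2, 0] is True (the second sequence would see the first one); wanted[2, 0] is False.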