From 960e17691072ce9ba71aff99395ed030f1da806b Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 15 Jan 2025 17:47:36 +0100
Subject: [PATCH] small update

---
 .../generation/continuous_batching.py | 38 +++++++++++--------
 .../integrations/sdpa_attention.py    |  8 +++-
 .../models/llama/modeling_llama.py    | 16 ++++----
 3 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py
index 6d4c794f2..f0f1108f1 100644
--- a/src/transformers/generation/continuous_batching.py
+++ b/src/transformers/generation/continuous_batching.py
@@ -256,14 +256,16 @@ class ContinuousBatch:
                 self.cache_index[k] += [new_block * cache.block_size]
             else:
                 self.cache_index[k] += [self.cache_index[k][-1] + 1]
+            position_ids += [self.position_ids[k][-1]]
             next_full_cache_position += self.cache_index[k]
+
         # how to efficiently select the next block? -> we probably just take the next longest sequence for now!
-        while len(self.next_ids) < self.batch_size and len(free_block_index) > 0:
+        while len(self.cumulative_seqlens_q) <= self.batch_size and len(free_block_index) > 0:
             next_sequence = self.input_tokens.pop(0)
             sample_length = len(next_sequence)
             if len(free_block_index) < (sample_length // cache.block_size) + 1:
                 # we have to make sure there are enough free blocks
-                self.input_tokens.insert(0, next_sequence)
+                self.input_tokens.append(next_sequence)  # requeue at the back instead of retrying it immediately
                 print("not enough memory to process this one, skipping")
                 continue
             new_ids += next_sequence
@@ -300,6 +302,7 @@ class ContinuousBatch:
         self.max_seqlens_k = max_seqlens_k
         self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k)
         self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q)
+        self.position_ids = position_ids
         assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong"
         position_ids = torch.tensor([position_ids])
         new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1)  # new sequence placed at the end
@@ -307,19 +310,22 @@
         # position_ids[0, self.cumulative_seqlens_k - 1] -> the last index?
         return new_ids, position_ids, torch.tensor(next_full_cache_position)

-    def evict_finished_sequences(self, generated_ids):
-        # 1. We need to del/index select only the generated_ids that are != eos token
-        keep_mask = generated_ids != self.eos_token_id
+    def update(self, generated_ids):
+        for k in self.next_ids:
+            self.generated_ids.append(k)  # add the token to the full sequence

-        # given the cumulative_seqlens_q we need to index the cache index as we only keep the last sequences
-        evict_mask = generated_ids == self.eos_token_id
-        del self.cache_index[evict_mask]
-        self.finished_sequences.append(generated_ids[evict_mask])
-        # we need to also update cumulative_seqlens_q and k
+        evict_mask = generated_ids == self.eos_token_id
+        keep_mask = ~evict_mask
+        self.next_ids = generated_ids[keep_mask].clone()
+        self.cumulative_seqlens_k = self.cumulative_seqlens_k[keep_mask]
+        self.cumulative_seqlens_q = self.cumulative_seqlens_q[keep_mask]

-        self.cumulative_seqlens_q -= self.cumulative_seqlens_q[::-1] # we shift to get the previous value
-        self.next_ids = generated_ids.index_select(keep_mask)
-        return self.next_ids
+        if evict_mask.any():
+            # delete the cache positions of the finished sequences, back to front so indices stay valid
+            for idx in sorted(torch.where(evict_mask)[0].tolist(), reverse=True):
+                del self.cache_index[idx]

     def __len__(self):
         return len(self.input_tokens)
@@ -371,11 +377,11 @@ class ContinuousMixin:
             }
             out = self.model.forward(
                 continous_batch, position_ids=position_ids, **kwargs
-            ).last_hidden_states
-            logits = self.model.lm_head(out[cache_index.cumulative_seqlens_q-1])
+            ).last_hidden_state
+            logits = self.lm_head(out[:, current_batch.cumulative_seqlens_q - 1, :])
             # we don't sample for now :)
             generated_ids = torch.argmax(logits, dim=-1)
-            current_batch.evict_finished_sequences(generated_ids)
+            current_batch.update(generated_ids[0])

             if len(current_batch.finished_sequences) > paged_attention_cache.batch_size:
                 yield current_batch.finished_sequences
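Note on the update() hunk above: removing several entries from a Python list by index only works back to front, otherwise each deletion shifts the indices that follow. A minimal standalone sketch of the eviction bookkeeping, assuming cache_index is a plain list of per-sequence block lists and the other arguments are 1-D tensors (the function name and example values are illustrative, not the final API):

    import torch

    def evict_finished(generated_ids, cache_index, cum_q, cum_k, eos_token_id):
        # sequences whose freshly generated token is EOS are done
        evict_mask = generated_ids == eos_token_id
        keep_mask = ~evict_mask
        # shrink the per-sequence tensors to the survivors
        next_ids = generated_ids[keep_mask].clone()
        cum_q, cum_k = cum_q[keep_mask], cum_k[keep_mask]
        # drop the block lists of finished sequences, back to front
        for idx in sorted(torch.where(evict_mask)[0].tolist(), reverse=True):
            del cache_index[idx]
        return next_ids, cache_index, cum_q, cum_k

    # the middle sequence of three hits EOS (= 2) and is evicted
    ids, blocks, cq, ck = evict_finished(
        torch.tensor([5, 2, 9]), [[0, 1], [2], [3, 4]],
        torch.tensor([1, 2, 3]), torch.tensor([10, 14, 20]), eos_token_id=2,
    )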
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index ef0155b22..edd72a7e1 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -30,7 +30,13 @@ def sdpa_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, None]:
     key, value = cache.update(key, value, module.layer_idx, cumulative_seqlens_k, **kwargs)
-
+    attention_mask_ = torch.full(
+        [1, 1, query.shape[2], query.shape[2] + 1], torch.finfo(query.dtype).min, device=query.device, dtype=query.dtype
+    )
+    attention_mask_[..., 0 : cumulative_seqlens_q[0], 0 : cumulative_seqlens_q[0]] = 0
+    for i in range(1, len(cumulative_seqlens_q)):
+        attention_mask_[..., cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i], cumulative_seqlens_q[i - 1] : cumulative_seqlens_q[i]] = 0
+    attention_mask = (attention_mask_ != 0).to(query.dtype) * torch.finfo(query.dtype).min
     if hasattr(module, "num_key_value_groups"):
         key = repeat_kv(key, module.num_key_value_groups)
         value = repeat_kv(value, module.num_key_value_groups)
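Note on the mask built above: continuous batching packs several sequences along one row, so the additive SDPA mask has to be block-diagonal (0 inside each sequence's own span, dtype-min everywhere else) so tokens cannot attend across sequence boundaries. A standalone sketch under the simplifying assumption q_len == kv_len; the WIP hunk above additionally pads the key dimension by one, and the function name here is illustrative:

    import torch

    def block_diagonal_mask(cu_seqlens, q_len, kv_len, dtype=torch.float32):
        # additive mask: 0 inside each packed sequence's block, dtype-min outside
        mask = torch.full((1, 1, q_len, kv_len), torch.finfo(dtype).min, dtype=dtype)
        prev = 0
        for end in cu_seqlens.tolist():
            mask[..., prev:end, prev:end] = 0.0
            prev = end
        return mask

    # two sequences of lengths 3 and 2 packed into q_len == kv_len == 5
    mask = block_diagonal_mask(torch.tensor([3, 5]), q_len=5, kv_len=5)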
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 16d4ba4da..bb90d5b7d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -637,14 +637,14 @@ class LlamaModel(LlamaPreTrainedModel):
         using_static_cache = isinstance(past_key_values, StaticCache)

         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
+        # if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(
+        #         attention_mask,
+        #         inputs_embeds=input_tensor,
+        #         past_key_values_length=past_seen_tokens,
+        #         is_training=self.training,
+        #     ):
+        #         return None

         dtype, device = input_tensor.dtype, input_tensor.device
         sequence_length = input_tensor.shape[1]
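Note on the modeling_llama.py hunk: _ignore_causal_mask_sdpa lets the model return None so SDPA can run without an explicit mask in the common causal case. Returning None here would drop the block-diagonal mask and let packed sequences attend across each other, hence the early exit is commented out while the continuous-batching path is wired up. A sketch of the downstream call, reusing block_diagonal_mask from the previous note (shapes and head count are made up, and intra-block causality is ignored here, as in the WIP hunk):

    import torch
    import torch.nn.functional as F

    # (batch, num_heads, seq_len, head_dim): two packed sequences of lengths 3 and 2
    q = torch.randn(1, 8, 5, 64)
    k = torch.randn(1, 8, 5, 64)
    v = torch.randn(1, 8, 5, 64)
    mask = block_diagonal_mask(torch.tensor([3, 5]), q_len=5, kv_len=5)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # mask broadcasts over heads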