From 05ea7b79e6903623e4d8e697c9be88462a8d8071 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Fri, 3 Nov 2023 12:05:55 +0100
Subject: [PATCH] Refactor: Use Llama RoPE implementation for Falcon (#26933)

* Use Llama RoPE implementation for Falcon

+ Add copy functionalities

* Use standard cache format for Falcon

* Simplify apply_rotary_pos_emb, copy from Llama

* Remove unnecessary cache conversion test

We don't need to convert any caches anymore!

* Resolve copy complaint
---
 .../models/falcon/modeling_falcon.py        | 327 +++++++-----------
 tests/models/falcon/test_modeling_falcon.py |  18 -
 2 files changed, 128 insertions(+), 217 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 511c55a84..d4c647c84 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -71,12 +71,43 @@ class FalconLinear(nn.Linear):
         return hidden_states + self.bias
 
 
-# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+# Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
 
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
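
As an aside, a minimal self-contained sketch of the rotation the helpers above apply, on toy tensors (all shapes and values below are illustrative only, not taken from the patch):

    import torch

    def rotate_half(x):
        # rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    batch, heads, seq_len, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)

    # cos/sin caches indexed by absolute position, shape [seq_len, head_dim]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

    position_ids = torch.arange(seq_len).unsqueeze(0)  # [batch, seq_len]
    # unsqueeze_dim=1 inserts the heads axis so cos/sin broadcast over it
    q_embed = (q * cos[position_ids].unsqueeze(1)) + (rotate_half(q) * sin[position_ids].unsqueeze(1))
    k_embed = (k * cos[position_ids].unsqueeze(1)) + (rotate_half(k) * sin[position_ids].unsqueeze(1))
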
@@ -90,138 +121,88 @@ def _get_unpad_data(attention_mask):
     )
 
 
-# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
 class FalconRotaryEmbedding(nn.Module):
-    """Implementation of RotaryEmbedding from GPT-NeoX.
-    This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
-    n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
-    """
-
-    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.base = base
+
+        self.dim = dim
         self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.head_dim = head_dim
-        self.seq_len_cached = -1
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device).to(dtype)
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
-
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
-
-    def cos_sin(
-        self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
-    ) -> torch.Tensor:
-        total_length = seq_len + past_key_values_length
-        if total_length > self.seq_len_cached:
-            self._set_cos_sin_cache(total_length, device, dtype)
-
-        # the cached tensors need to update their devices (for example, after we change the model's device)
-        self.cos_cached = self.cos_cached.to(device)
-        self.sin_cached = self.sin_cached.to(device)
-
-        # Gather cos, sin at the designated position ids
-        cos = self.cos_cached[position_ids]  # [bs, seq_len, dim]
-        sin = self.sin_cached[position_ids]  # [bs, seq_len, dim]
-        return cos, sin
-
-    def forward(self, query, key, past_key_values_length, position_ids):
-        _, seq_len, _ = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
-        # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
-        # avoid unnecessary repeat_interleave operations.
-        query_expansion_factor = int(query.shape[0] / cos.shape[0])
-        if query_expansion_factor > 1:
-            query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
-            query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
-        else:
-            query_cos, query_sin = cos, sin
-
-        key_expansion_factor = int(key.shape[0] / cos.shape[0])
-        if key_expansion_factor > 1:
-            if key_expansion_factor != query_expansion_factor:
-                key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
-                key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
-            else:
-                key_cos, key_sin = query_cos, query_sin
-        else:
-            key_cos, key_sin = cos, sin
-
-        return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
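
A hedged usage sketch of the rewritten module, assuming this patch is applied (the 3000-token length is an arbitrary illustration): requesting a seq_len beyond max_position_embeddings triggers _set_cos_sin_cache to regrow the buffers, and forward now returns per-position cos/sin slices instead of already-rotated tensors:

    import torch
    from transformers.models.falcon.modeling_falcon import FalconRotaryEmbedding

    rope = FalconRotaryEmbedding(dim=64, max_position_embeddings=2048)
    x = torch.randn(1, 8, 3000, 64)   # [batch, heads, seq, head_dim]
    cos, sin = rope(x, seq_len=3000)  # 3000 > 2048, so the cache is rebuilt
    assert cos.shape == (3000, 64) and sin.shape == (3000, 64)
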
- query_expansion_factor = int(query.shape[0] / cos.shape[0]) - if query_expansion_factor > 1: - query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) - query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) - else: - query_cos, query_sin = cos, sin - - key_expansion_factor = int(key.shape[0] / cos.shape[0]) - if key_expansion_factor > 1: - if key_expansion_factor != query_expansion_factor: - key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) - key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0) - else: - key_cos, key_sin = query_cos, query_sin - else: - key_cos, key_sin = cos, sin - - return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(head_dim, base, max_position_embeddings) + super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=device).to(dtype) - # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = t / self.scaling_factor freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(device) - - if dtype in [torch.float16, torch.bfloat16]: - emb = emb.float() - - self.cos_cached = emb.cos() - self.sin_cached = emb.sin() - - self.cos_cached = self.cos_cached.type(dtype) - self.sin_cached = self.sin_cached.type(dtype) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): - """ - FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla - """ + """FalconRotaryEmbedding extended with Dynamic NTK scaling. 
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
 class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
-    """
-    FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
-    """
+    """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
-    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
-        super().__init__(head_dim, base, max_position_embeddings)
+        super().__init__(dim, max_position_embeddings, base, device)
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.seq_len_cached = seq_len
+        self.max_seq_len_cached = seq_len
 
-        # This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
-            ) ** (self.head_dim / (self.head_dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
             self.register_buffer("inv_freq", inv_freq, persistent=False)
 
-        t = torch.arange(seq_len, device=device).to(dtype)
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
-
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
-
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
 def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
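
For intuition on the NTK branch above: once seq_len exceeds max_position_embeddings, the rotary base is inflated, which lowers all rotation frequencies and stretches the usable context. A standalone sketch with illustrative numbers:

    import torch

    base, dim = 10000.0, 64
    max_position_embeddings, scaling_factor = 2048, 1.0
    seq_len = 4096  # twice the training context
    adjusted_base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
    inv_freq = 1.0 / (adjusted_base ** (torch.arange(0, dim, 2).float() / dim))
    # adjusted_base ~ 2.05e4 here: larger base -> slower rotations -> longer reach
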
@@ -293,6 +274,8 @@ class FalconAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.split_size = self.hidden_size
         self.hidden_dropout = config.hidden_dropout
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
         self.is_causal = True
 
         if self.head_dim * self.num_heads != self.hidden_size:
@@ -301,7 +284,8 @@ class FalconAttention(nn.Module):
                 f" {self.num_heads})."
             )
 
-        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
+        if config.rotary:
+            self._init_rope()
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -319,33 +303,33 @@ class FalconAttention(nn.Module):
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
     def _init_rope(self):
         if self.config.rope_scaling is None:
-            rotary_emb = FalconRotaryEmbedding(
+            self.rotary_emb = FalconRotaryEmbedding(
                 self.head_dim,
-                base=self.config.rope_theta,
-                max_position_embeddings=self.config.max_position_embeddings,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
-                rotary_emb = FalconLinearScalingRotaryEmbedding(
+                self.rotary_emb = FalconLinearScalingRotaryEmbedding(
                     self.head_dim,
-                    base=self.config.rope_theta,
-                    max_position_embeddings=self.config.max_position_embeddings,
+                    max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
+                    base=self.rope_theta,
                 )
             elif scaling_type == "dynamic":
-                rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
+                self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
                     self.head_dim,
-                    base=self.config.rope_theta,
-                    max_position_embeddings=self.config.max_position_embeddings,
+                    max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
+                    base=self.rope_theta,
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-        return rotary_emb
 
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
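
The scaling variant is still selected from config.rope_scaling, so existing checkpoints are unaffected; a hedged sketch of how a user would opt in (the tiny dimensions are illustrative only):

    from transformers import FalconConfig, FalconForCausalLM

    config = FalconConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        rope_scaling={"type": "dynamic", "factor": 2.0},
    )
    model = FalconForCausalLM(config)  # _init_rope picks FalconDynamicNTKScalingRotaryEmbedding
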
@@ -428,35 +412,31 @@ class FalconAttention(nn.Module):
 
         batch_size, query_length, _, _ = query_layer.shape
 
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
+        kv_seq_len = key_layer.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+        if alibi is None:
+            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
+            #  - key: [batch_size, self.num_heads, kv_length, head_dim]
+            #  - value: [batch_size, self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=-2)
+            value_layer = torch.cat((past_value, value_layer), dim=-2)
 
-        _, kv_length, _ = key_layer.shape
+        kv_length = key_layer.shape[-2]
         if use_cache:
             present = (key_layer, value_layer)
         else:
             present = None
 
-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-
         if alibi is None:
             if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
                 # TODO: deprecate this once we add FA2 support in Falcon
@@ -467,15 +447,15 @@ class FalconAttention(nn.Module):
                 )
 
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                    query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
                 )
                 attention_scores = None
             else:
-                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
+                attention_scores = query_layer @ key_layer.transpose(-1, -2)
                 attention_scores /= math.sqrt(self.head_dim)
 
                 attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
-                attn_output = attention_scores @ value_layer_
+                attn_output = attention_scores @ value_layer
 
                 attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
                 attn_output = attn_output.permute(0, 2, 1, 3)
@@ -489,7 +469,7 @@ class FalconAttention(nn.Module):
             return output_tensor, present
 
         else:
-            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+            matmul_result = query_layer @ key_layer.transpose(-1, -2)
 
             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
@@ -516,7 +496,7 @@ class FalconAttention(nn.Module):
             attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
 
             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+            context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
 
             # change view [batch_size, q_length, num_heads * head_dim]
             context_layer = self._merge_heads(context_layer)
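
With the batch and head dimensions kept separate, the KV cache stays in the standard [batch_size, num_heads, kv_length, head_dim] layout end to end, and extending it is a single concatenation along dim=-2. A toy sketch (shapes are illustrative):

    import torch

    batch, num_heads, head_dim = 1, 4, 8
    past_key = torch.randn(batch, num_heads, 5, head_dim)  # 5 cached positions
    new_key = torch.randn(batch, num_heads, 1, head_dim)   # 1 freshly computed position
    key_layer = torch.cat((past_key, new_key), dim=-2)
    assert key_layer.shape == (batch, num_heads, 6, head_dim)
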
@@ -563,37 +543,27 @@ class FalconFlashAttention2(FalconAttention):
 
         batch_size, query_length, _, _ = query_layer.shape
 
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
+        kv_seq_len = key_layer.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+        if alibi is None:
+            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None and use_cache:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        _, kv_seq_length, _ = key_layer.shape
-
-        torch_dtype = query_layer.dtype
+            #  - key: [batch_size, self.num_heads, kv_length, head_dim]
+            #  - value: [batch_size, self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=-2)
+            value_layer = torch.cat((past_value, value_layer), dim=-2)
 
         past_key_value = (key_layer, value_layer) if use_cache else None
 
-        query_layer = (
-            query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-        )
-        key_layer = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-        value_layer = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-
         if alibi is not None:
             raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
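
Dropping the [batch * heads, seq, head_dim] flattening also removes the reshape/transpose round-trips before attention: both the SDPA path and the flash-attention path can consume the 4D layers directly. A minimal sketch of the SDPA shape contract (requires PyTorch 2.0+; values are illustrative):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 6, 8)  # [batch, heads, seq, head_dim]
    k = torch.randn(1, 4, 6, 8)
    v = torch.randn(1, 4, 6, 8)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
    assert out.shape == q.shape
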
@@ -940,42 +910,6 @@ class FalconPreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    @staticmethod
-    def _convert_cache_to_standard_format(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
-        # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
-        # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
-        # on whether we use multi_query attention.
-        num_heads = batch_size_times_num_heads // batch_size
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
-                layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_rw_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
-                layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
 
 @add_start_docstrings(
     "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1046,8 +980,6 @@ class FalconModel(FalconPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -1073,7 +1005,7 @@ class FalconModel(FalconPreTrainedModel):
         # Compute alibi tensor: check build_alibi_tensor documentation
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
+            past_key_values_length = past_key_values[0][0].shape[-2]
 
         if self.use_alibi:
             mask = (
@@ -1143,9 +1075,6 @@ class FalconModel(FalconPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py
index 75b1e3e46..5956a9ed6 100644
--- a/tests/models/falcon/test_modeling_falcon.py
+++ b/tests/models/falcon/test_modeling_falcon.py
@@ -340,24 +340,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    def test_cache_conversions(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = input_dict["input_ids"]
-        model = FalconForCausalLM(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, use_cache=True)
-        batch_size = input_ids.shape[0]
-        rw_cache = model._convert_to_rw_cache(result.past_key_values)
-        standard_cache = model._convert_cache_to_standard_format(rw_cache, batch_size)
-        for layer in range(len(rw_cache)):
-            for tensor_idx in range(2):
-                self.assertTrue(rw_cache[layer][tensor_idx].ndim == 3)
-                self.assertTrue(result.past_key_values[layer][tensor_idx].ndim == 4)
-                self.assertTrue(
-                    torch.all(result.past_key_values[layer][tensor_idx] == standard_cache[layer][tensor_idx])
-                )
-
     def test_falcon_sequence_classification_model_for_multi_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
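
After this change, past_key_values leave the model already in the standard format, which is what made the conversion helpers and their test removable. A hedged end-to-end sketch, assuming this patch is applied (the tiny config is illustrative only):

    import torch
    from transformers import FalconConfig, FalconForCausalLM

    config = FalconConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
    model = FalconForCausalLM(config).eval()
    with torch.no_grad():
        out = model(torch.randint(0, 100, (1, 5)), use_cache=True)
    for key, value in out.past_key_values:
        assert key.ndim == 4  # [batch_size, num_kv_heads, kv_length, head_dim]
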