From ebad35797f18a4aa837d80e28e4de44a78a32689 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Sun, 8 Dec 2024 22:41:36 +0000 Subject: [PATCH] fixing_qwen2moe --- .../models/qwen2/modeling_qwen2.py | 5 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 415 ++++++++---------- 2 files changed, 180 insertions(+), 240 deletions(-) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index a8e3033d9..60e59cbf9 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -56,6 +56,7 @@ from .configuration_qwen2 import Qwen2Config if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward + if is_torch_greater_or_equal("2.5"): from torch.nn.attention.flex_attention import flex_attention @@ -490,7 +491,7 @@ class Qwen2FlashAttention2(Qwen2Attention): self.config._attn_implementation = "flash_attention_2" logger.warning_once( "The `Qwen2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" + "attribute of the `Qwen2Attention` class! It will be removed in v4.48" ) @@ -499,7 +500,7 @@ class Qwen2SdpaAttention(Qwen2Attention): super().__init__(config, layer_idx) self.config._attn_implementation = "sdpa" logger.warning_once( - "The `Qwen2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "The `Qwen2SdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`" "attribute of the `Qwen2Attention` class! It will be removed in v4.48" ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6c5cbec22..cab3ebdc8 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -46,6 +46,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_greater_or_equal, logging, replace_return_docstrings, ) @@ -55,6 +56,10 @@ from .configuration_qwen2_moe import Qwen2MoeConfig if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B" @@ -318,6 +323,139 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim) + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + + return 
attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + dropout_rate = 0.0 if not config.training else config.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(config.config, "_pre_quantization_dtype"): + target_dtype = config.config._pre_quantization_dtype + else: + target_dtype = config.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + query = query.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + position_ids = _kwargs["position_ids"] + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + query_states.shape[1], + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=config.sliding_window, + is_causal=config.is_causal, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + ) + + return attn_output, None + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + q_len = query.shape[-2] + + causal_mask = mask + if mask is not None: + causal_mask = mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
+ if query.device.type == "cuda" and mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + causal_mask = mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if causal_mask is not None: + score += causal_mask[b][0][q_idx][kv_idx] + return score + + attn_output, attn_weights = flex_attention( + query, + key, + value, + score_mod=causal_mod, + enable_gqa=True, + return_lse=True, + ) + attn_weights = attn_weights.to(value.dtype) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +QWEN2_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe class Qwen2MoeAttention(nn.Module): """ @@ -346,6 +484,15 @@ class Qwen2MoeAttention(nn.Module): self.is_causal = True self.attention_dropout = config.attention_dropout + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = self.config.sliding_window + else: + self.sliding_window = None + if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -358,94 +505,6 @@ class Qwen2MoeAttention(nn.Module): self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) - # Ignore copy - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe -class Qwen2MoeFlashAttention2(Qwen2MoeAttention): - """ - Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( @@ -458,7 +517,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -485,62 +544,23 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") + attention_type = "flex_attention" else: - sliding_window = None - - attn_output = _flash_attention_forward( + attention_type = self.config._attn_implementation + attn_output, attn_weights = QWEN2_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, - q_len, + output_attentions=output_attentions, position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) if not output_attentions: @@ -549,108 +569,26 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe -class Qwen2MoeSdpaAttention(Qwen2MoeAttention): - """ - Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2MoeAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +class Qwen2MoeFlashAttention2(Qwen2MoeAttention): + def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "flash_attention_2" + logger.warning_once( + "The `Qwen2MoeFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `Qwen2MoeAttention` class! 
It will be removed in v4.48" ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2MOE_ATTENTION_CLASSES = { - "eager": Qwen2MoeAttention, - "flash_attention_2": Qwen2MoeFlashAttention2, - "sdpa": Qwen2MoeSdpaAttention, -} +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +class Qwen2MoeSdpaAttention(Qwen2MoeAttention): + def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config._attn_implementation = "sdpa" + logger.warning_once( + "The `Qwen2MoeSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `Qwen2MoeAttention` class! It will be removed in v4.48" + ) class Qwen2MoeSparseMoeBlock(nn.Module): @@ -719,8 +657,8 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__(self, config: Qwen2MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.config = config + self.self_attn = Qwen2MoeAttention(config=config, layer_idx=layer_idx) if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 @@ -843,6 +781,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_range
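The core of this patch is the move from one attention subclass per backend (Qwen2MoeFlashAttention2, Qwen2MoeSdpaAttention) to a single Qwen2MoeAttention module that dispatches at runtime to a plain forward function through the QWEN2_ATTENTION_FUNCTION dict, keyed by config._attn_implementation, and that falls back to flex_attention when output_attentions=True because the fused SDPA and flash kernels cannot return attention weights. The sketch below is a minimal, dependency-free illustration of that dispatch pattern, not the transformers API: TinyAttention, ATTENTION_FUNCTIONS, and the eager fallback are simplified stand-ins for the names used in the patch.

import math
import torch


def eager_attention_forward(module, query, key, value, mask, **kwargs):
    # Plain softmax(QK^T / sqrt(d)) V, mirroring the eager path added by the patch.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.shape[-1])
    if mask is not None:
        attn_weights = attn_weights + mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value), attn_weights


def sdpa_attention_forward(module, query, key, value, mask, **kwargs):
    # Fused kernel path; it cannot expose per-head attention weights.
    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    return attn_output, None


# Counterpart of QWEN2_ATTENTION_FUNCTION: implementation name -> forward function.
ATTENTION_FUNCTIONS = {
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}


class TinyAttention(torch.nn.Module):
    def __init__(self, attn_implementation="sdpa"):
        super().__init__()
        self.attn_implementation = attn_implementation

    def forward(self, query, key, value, mask=None, output_attentions=False):
        # Same fallback idea as the patch: a backend that cannot return weights is
        # swapped out when the caller asks for them (the patch falls back to
        # flex_attention; eager is used here to keep the sketch dependency-free).
        impl = self.attn_implementation
        if output_attentions and impl in ("sdpa", "flash_attention_2"):
            impl = "eager"
        return ATTENTION_FUNCTIONS[impl](self, query, key, value, mask)


query = key = value = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
attn = TinyAttention("sdpa")
output, weights = attn(query, key, value, output_attentions=True)
print(output.shape, weights.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 4])

Keeping the backends as free functions in a dict means a new kernel (such as flex attention) can be registered without adding another nn.Module subclass, which is why the old per-backend classes survive in the patch only as thin deprecation shims that set config._attn_implementation and warn.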