fixing_qwen2moe

MekkCyber 2024-12-08 22:41:36 +00:00
parent 32ed852df0
commit ebad35797f
2 changed files with 180 additions and 240 deletions

src/transformers/models/qwen2/modeling_qwen2.py

@ -56,6 +56,7 @@ from .configuration_qwen2 import Qwen2Config
if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention
@ -490,7 +491,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
self.config._attn_implementation = "flash_attention_2"
logger.warning_once(
"The `Qwen2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GemmaAttention` class! It will be removed in v4.48"
"attribute of the `Qwen2Attention` class! It will be removed in v4.48"
)
@ -499,7 +500,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
super().__init__(config, layer_idx)
self.config._attn_implementation = "sdpa"
logger.warning_once(
"The `Qwen2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"The `Qwen2SdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `Qwen2Attention` class! It will be removed in v4.48"
)

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

@ -46,6 +46,7 @@ from ...utils import (
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_greater_or_equal,
    logging,
    replace_return_docstrings,
)
@ -55,6 +56,10 @@ from .configuration_qwen2_moe import Qwen2MoeConfig
if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B"
@ -318,6 +323,139 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(config, query, key, value, mask, **_kwargs):
    key_states = repeat_kv(key, config.num_key_value_groups)
    value_states = repeat_kv(value, config.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim)
    if mask is not None:  # no matter the length, we just slice it
        causal_mask = mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
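As a reading aid, here is a minimal, self-contained sketch of what the eager path above computes, including the grouped-query expansion done by repeat_kv; the helper name and every shape below are hypothetical, not part of the diff.

import math
import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    b, kv, s, d = x.shape
    if n_rep == 1:
        return x
    return x[:, :, None, :, :].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)

# hypothetical shapes: 8 query heads, 2 KV heads -> num_key_value_groups = 4
q = torch.randn(1, 8, 5, 64)
k = torch.randn(1, 2, 5, 64)
v = torch.randn(1, 2, 5, 64)
k, v = repeat_kv_sketch(k, 4), repeat_kv_sketch(v, 4)

scores = q @ k.transpose(2, 3) / math.sqrt(64)                     # (1, 8, 5, 5)
weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
out = (weights @ v).transpose(1, 2).contiguous()                   # (1, 5, 8, 64)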
def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
    key_states = repeat_kv(key, config.num_key_value_groups)
    value_states = repeat_kv(value, config.num_key_value_groups)

    dropout_rate = 0.0 if not config.training else config.attention_dropout

    # In PEFT, we usually cast the layer norms to float32 for training stability,
    # so the input hidden states get silently cast to float32. Hence, we need to
    # cast them back to float16 just to be sure everything works as expected.
    input_dtype = query.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(config.config, "_pre_quantization_dtype"):
            target_dtype = config.config._pre_quantization_dtype
        else:
            target_dtype = config.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query = query.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    position_ids = _kwargs["position_ids"]
    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        mask,
        query_states.shape[1],
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=config.sliding_window,
        is_causal=config.is_causal,
        use_top_left_mask=config._flash_attn_uses_top_left_mask,
    )

    return attn_output, None
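The float32 recast above walks a priority list when choosing its target dtype; the helper below is a hypothetical, stand-alone restatement of that order (the real code reads these attributes off the attention module and its config):

import torch
import torch.nn as nn

def pick_flash_dtype(proj: nn.Linear) -> torch.dtype:
    # hypothetical helper mirroring the priority order used above
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()      # 1) the active autocast dtype
    if hasattr(proj, "_pre_quantization_dtype"):
        return proj._pre_quantization_dtype        # 2) the dtype recorded before quantization
    return proj.weight.dtype                       # 3) fall back to the projection weight dtype

print(pick_flash_dtype(nn.Linear(8, 8).half()))    # torch.float16 outside of autocast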
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
    key = repeat_kv(key, config.num_key_value_groups)
    value = repeat_kv(value, config.num_key_value_groups)

    q_len = query.shape[-2]

    causal_mask = mask
    if mask is not None:
        causal_mask = mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query.device.type == "cuda" and mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

    is_causal = True if causal_mask is None and q_len > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=config.attention_dropout if config.training else 0.0,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
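For reference, a tiny runnable sketch of the SDPA call pattern above, showing the choice between an explicit additive mask and the is_causal fast path; all shapes below are made up:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 5, 64)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)

# with no explicit mask and q_len > 1, rely on the kernel's built-in causal masking
causal_mask = None
is_causal = causal_mask is None and q.shape[-2] > 1

out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=is_causal)
print(out.shape)  # torch.Size([1, 8, 5, 64])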
def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
    causal_mask = mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        return score

    attn_output, attn_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        return_lse=True,
    )
    attn_weights = attn_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
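A self-contained sketch of the flex_attention call used above, assuming torch >= 2.5; the additive causal bias and every shape below are hypothetical, chosen only to mirror the score_mod pattern in the diff:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, Hq, Hkv, L, D = 1, 8, 2, 16, 64
q = torch.randn(B, Hq, L, D)
k = torch.randn(B, Hkv, L, D)   # fewer KV heads; enable_gqa=True expands them internally
v = torch.randn(B, Hkv, L, D)

# additive causal bias: 0 on and below the diagonal, -inf strictly above it
bias = torch.full((B, 1, L, L), float("-inf")).triu(1)

def causal_mod(score, b, h, q_idx, kv_idx):
    return score + bias[b][0][q_idx][kv_idx]

out, lse = flex_attention(q, k, v, score_mod=causal_mod, enable_gqa=True, return_lse=True)
print(out.shape, lse.shape)  # torch.Size([1, 8, 16, 64]) torch.Size([1, 8, 16])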
QWEN2_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "flex_attention": flex_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
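This table replaces the per-implementation subclasses that the rest of the diff deprecates: the forward pass looks the implementation up by name instead of dispatching through inheritance. A minimal, hypothetical sketch of that pattern (the stubs stand in for the real functions above):

from typing import Callable, Dict

def eager_stub(*args, **kwargs):
    return "eager result"

def sdpa_stub(*args, **kwargs):
    return "sdpa result"

ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_stub,
    "sdpa": sdpa_stub,
}

attn_implementation = "sdpa"                      # normally read from config._attn_implementation
attention_fn = ATTENTION_FUNCTIONS[attn_implementation]
print(attention_fn())                             # sdpa result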
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
class Qwen2MoeAttention(nn.Module):
"""
@ -346,6 +484,15 @@ class Qwen2MoeAttention(nn.Module):
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            self.sliding_window = self.config.sliding_window
        else:
            self.sliding_window = None
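To make the gating above concrete, a small sketch that lists which layer indices end up with a sliding window under the condition exactly as written; every config value below is made up for illustration:

# hypothetical config values, purely for illustration
use_sliding_window = True
sliding_window = 4096
max_window_layers = 28
num_hidden_layers = 32

windowed = [
    idx
    for idx in range(num_hidden_layers)
    if use_sliding_window and sliding_window is not None and idx >= max_window_layers
]
print(windowed)  # [28, 29, 30, 31] -> only layers at or above max_window_layers keep a window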
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
@ -358,94 +505,6 @@ class Qwen2MoeAttention(nn.Module):
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
"""
Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
as the weights of the module stays untouched. The only required change would be on the forward pass
where it needs to correctly call the public API of flash attention and deal with padding tokens
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
config.max_window_layers layers.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
@ -458,7 +517,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@ -485,62 +544,23 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
attention_type = "flex_attention"
else:
sliding_window = None
attn_output = _flash_attention_forward(
attention_type = self.config._attn_implementation
attn_output, attn_weights = QWEN2_ATTENTION_FUNCTION[attention_type](
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
output_attentions=output_attentions,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
@ -549,108 +569,26 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
return attn_output, attn_weights, past_key_value
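Note that the refactored forward pass reroutes to flex_attention whenever attention weights are requested, since the flash and SDPA kernels never materialize the attention matrix; a small sketch of that selection logic with hypothetical inputs:

# hypothetical inputs mirroring the branch in the forward pass above
attn_implementation = "flash_attention_2"   # normally config._attn_implementation
output_attentions = True

if output_attentions and attn_implementation in ("sdpa", "flash_attention_2"):
    # these kernels cannot return attention weights, so fall back to flex_attention
    attention_type = "flex_attention"
else:
    attention_type = attn_implementation

print(attention_type)  # flex_attention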
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Qwen2MoeAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.config._attn_implementation = "flash_attention_2"
        logger.warning_once(
            "The `Qwen2MoeFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
            "attribute of the `Qwen2MoeAttention` class! It will be removed in v4.48"
        )
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
QWEN2MOE_ATTENTION_CLASSES = {
"eager": Qwen2MoeAttention,
"flash_attention_2": Qwen2MoeFlashAttention2,
"sdpa": Qwen2MoeSdpaAttention,
}
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.config._attn_implementation = "sdpa"
        logger.warning_once(
            "The `Qwen2MoeSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`"
            "attribute of the `Qwen2MoeAttention` class! It will be removed in v4.48"
        )
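Since both wrappers above are now deprecation shims, the attention backend is meant to be selected through the config at load time instead. A hedged usage sketch, using the checkpoint already referenced in this file's docstrings (downloading its weights is of course required):

from transformers import AutoModelForCausalLM

# "eager", "sdpa", "flash_attention_2" and, per this commit, "flex_attention" are the
# available choices, subject to the corresponding flash-attn / torch >= 2.5 installs
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-57B-A14B",
    attn_implementation="sdpa",
)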
class Qwen2MoeSparseMoeBlock(nn.Module):
@ -719,8 +657,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.config = config
self.self_attn = Qwen2MoeAttention(config=config, layer_idx=layer_idx)
if (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
@ -843,6 +781,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_flex_attn = True
def _init_weights(self, module):
std = self.config.initializer_range