From f4f0ebd81cef3ecf140bab89fcd931b57b89b8a1 Mon Sep 17 00:00:00 2001 From: ryan u Date: Thu, 30 Jan 2025 20:00:39 +0900 Subject: [PATCH] hold code only for checkpoints congifuration; remove redundant --- .../deepseek_v3/configuration_deepseek_v3.py | 36 +- .../deepseek_v3/modeling_deepseek_v3.py | 366 +++++++--------- .../models/deepseek_v3/modular_deepseek_v3.py | 405 ++++++++---------- 3 files changed, 337 insertions(+), 470 deletions(-) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index 7011fdeeb..8d1c56343 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -45,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig): Dimension of the MoE representations. num_hidden_layers (`int`, *optional*, defaults to 61): Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. num_attention_heads (`int`, *optional*, defaults to 128): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 128): @@ -58,11 +56,9 @@ class DeepseekV3Config(PretrainedConfig): paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts, None means dense model. + Number of shared experts. n_routed_experts (`int`, *optional*, defaults to 256): - Number of routed experts, None means dense model. - ep_size (`int`, *optional*, defaults to 1): - Expert parallelism size for distributed training. + Number of routed experts. routed_scaling_factor (`float`, *optional*, defaults to 2.5): Scaling factor or routed experts. kv_lora_rank (`int`, *optional*, defaults to 512): @@ -75,28 +71,20 @@ class DeepseekV3Config(PretrainedConfig): Dimension of the value heads. qk_nope_head_dim (`int`, *optional*, defaults to 128): Dimension of the query/key heads that don't use rotary position embeddings. - topk_method (`str`, *optional*, defaults to `"noaux_tc"`): - Topk method used in routed gate. n_group (`int`, *optional*, defaults to 8): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to 4): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). num_experts_per_tok (`int`, *optional*, defaults to 8): Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. first_k_dense_replace (`int`, *optional*, defaults to 3): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to `"sigmoid"`): - Method of computing expert weights. aux_loss_alpha (`float`, *optional*, defaults to 0.001): Auxiliary loss weight coefficient. Whether to compute the auxiliary loss for each individual sample. - seq_aux (`bool`, *optional*, defaults to `True`): - Whether to compute auxiliary loss at sequence level. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 4096): @@ -145,6 +133,12 @@ class DeepseekV3Config(PretrainedConfig): model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `DeepseekV3Model` + base_model_tp_plan = { + "layers.*.gate_proj": "colwise", + "layers.*.up_proj": "colwise", + "layers.*.down_proj": "rowwise", + } def __init__( self, @@ -153,28 +147,22 @@ class DeepseekV3Config(PretrainedConfig): intermediate_size=18432, moe_intermediate_size=2048, num_hidden_layers=61, - num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, n_shared_experts=1, n_routed_experts=256, - ep_size=1, routed_scaling_factor=2.5, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, - topk_method="noaux_tc", n_group=8, topk_group=4, num_experts_per_tok=8, - moe_layer_freq=1, first_k_dense_replace=3, norm_topk_prob=True, - scoring_func="sigmoid", aux_loss_alpha=0.001, - seq_aux=True, hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, @@ -197,27 +185,23 @@ class DeepseekV3Config(PretrainedConfig): self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers - self.num_nextn_predict_layers = num_nextn_predict_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts - self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim - self.topk_method = topk_method + self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_rope_head_dim self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok - self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob - self.scoring_func = scoring_func self.aux_loss_alpha = aux_loss_alpha - self.seq_aux = seq_aux # for backward compatibility if num_key_value_heads is None: diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index c823a24af..bb5aad407 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -7,9 +7,7 @@ import math from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn @@ -142,63 +140,41 @@ class MoEGate(nn.Module): self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group - - # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) - if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) - self.reset_parameters() - def reset_parameters(self) -> None: - import torch.nn.init as init - - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - bsz, seq_len, h = hidden_states.shape - ### compute gating score - hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + batch_size, seq_length = hidden_states.shape[:-1] + hidden_states = hidden_states.view(-1, self.config.hidden_size) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - ### select top-k experts - if self.topk_method == "noaux_tc": - assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") - - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor - - return topk_idx, topk_weight + scores = logits.sigmoid() + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) # [n, e] + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor + return topk_indices, topk_weights class DeepseekV3MoE(nn.Module): @@ -209,116 +185,84 @@ class DeepseekV3MoE(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) - else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - for i in range(config.n_routed_experts) - ] - ) + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) + topk_indices, topk_weights = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) + y = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape) + y = y + self.shared_experts(identity) return y @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( - dim=0 - ) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather + def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + """ + Perform inference using a Mixture of Experts (MoE) model. + + Args: + hidden_states (torch.Tensor): Input hidden states. + topk_indices (torch.Tensor): Indices of the top-k experts for each token. + topk_weights (torch.Tensor): Weights associated with the top-k experts. + + Returns: + torch.Tensor: Output of the MoE model. + """ + num_experts = len(self.experts) + batch_size, num_topk = topk_indices.shape + + # Count the number of tokens assigned to each expert + expert_counts = topk_indices.new_zeros((batch_size, num_experts)) + expert_counts.scatter_(1, topk_indices, 1) + tokens_per_expert = expert_counts.sum(dim=0) + + # Sort tokens by their assigned expert + sorted_indices = topk_indices.view(-1).argsort() + sorted_tokens = hidden_states[sorted_indices // num_topk] tokens_per_expert = tokens_per_expert.cpu().numpy() - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens + # Process tokens through their assigned experts + expert_outputs = [] + current_pos = 0 + + for expert_idx, num_tokens in enumerate(tokens_per_expert): if num_tokens == 0: continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens + next_pos = current_pos + num_tokens + expert = self.experts[expert_idx] + expert_tokens = sorted_tokens[current_pos:next_pos] + expert_outputs.append(expert(expert_tokens)) + current_pos = next_pos - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - return final_out + # Combine the outputs from all experts + expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0) + + # Reorder the outputs to match the original token sequence + reordered_outputs = torch.empty_like(expert_outputs) + reordered_outputs[sorted_indices] = expert_outputs + + # Reshape and apply the expert weights + reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype) + moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs) + moe_output = moe_output.sum(dim=1).type(hidden_states.dtype) + return moe_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -359,16 +303,14 @@ def eager_attention_forward( return attn_output, attn_weights -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. - Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. @@ -379,15 +321,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_length, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_length, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + the shape [batch_size, seq_length, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -396,111 +345,87 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.q_head_dim = config.q_head_dim self.is_causal = True - - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( - config.kv_lora_rank, + self.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, - self.hidden_size, + config.hidden_size, bias=config.attention_bias, ) - self.rotary_emb = DeepseekV3RotaryEmbedding( - config=self.config, - ) + self.scaling = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, self.num_heads, -1) + batch_size, seq_length = input_shape - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) + k_pe = k_pe.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(hidden_shape).transpose(1, 2) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + cos, sin = position_embeddings + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: @@ -528,9 +453,12 @@ class DeepseekV3Attention(nn.Module): scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -541,15 +469,11 @@ class DeepseekV3DecoderLayer(nn.Module): self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index ee648fae5..8b20f068d 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -1,9 +1,7 @@ import math from typing import Callable, Optional, Tuple -import numpy as np import torch -import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -15,7 +13,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging from ..llama.modeling_llama import ( - LlamaDecoderLayer, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, @@ -23,6 +20,7 @@ from ..llama.modeling_llama import ( LlamaRMSNorm, LlamaRotaryEmbedding, eager_attention_forward, + rotate_half, ) from .configuration_deepseek_v3 import DeepseekV3Config @@ -38,16 +36,14 @@ class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding): pass -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. - Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. @@ -58,15 +54,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_length, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_length, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + the shape [batch_size, seq_length, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -96,63 +99,41 @@ class MoEGate(nn.Module): self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group - - # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) - if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) - self.reset_parameters() - def reset_parameters(self) -> None: - import torch.nn.init as init - - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - bsz, seq_len, h = hidden_states.shape - ### compute gating score - hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + batch_size, seq_length = hidden_states.shape[:-1] + hidden_states = hidden_states.view(-1, self.config.hidden_size) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - ### select top-k experts - if self.topk_method == "noaux_tc": - assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") - - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor - - return topk_idx, topk_weight + scores = logits.sigmoid() + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) # [n, e] + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor + return topk_indices, topk_weights class DeepseekV3MoE(nn.Module): @@ -163,226 +144,163 @@ class DeepseekV3MoE(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) - else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - for i in range(config.n_routed_experts) - ] - ) + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) + topk_indices, topk_weights = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) + y = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape) + y = y + self.shared_experts(identity) return y @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( - dim=0 - ) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather + def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + """ + Perform inference using a Mixture of Experts (MoE) model. + + Args: + hidden_states (torch.Tensor): Input hidden states. + topk_indices (torch.Tensor): Indices of the top-k experts for each token. + topk_weights (torch.Tensor): Weights associated with the top-k experts. + + Returns: + torch.Tensor: Output of the MoE model. + """ + num_experts = len(self.experts) + batch_size, num_topk = topk_indices.shape + + # Count the number of tokens assigned to each expert + expert_counts = topk_indices.new_zeros((batch_size, num_experts)) + expert_counts.scatter_(1, topk_indices, 1) + tokens_per_expert = expert_counts.sum(dim=0) + + # Sort tokens by their assigned expert + sorted_indices = topk_indices.view(-1).argsort() + sorted_tokens = hidden_states[sorted_indices // num_topk] tokens_per_expert = tokens_per_expert.cpu().numpy() - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens + # Process tokens through their assigned experts + expert_outputs = [] + current_pos = 0 + + for expert_idx, num_tokens in enumerate(tokens_per_expert): if num_tokens == 0: continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens + next_pos = current_pos + num_tokens + expert = self.experts[expert_idx] + expert_tokens = sorted_tokens[current_pos:next_pos] + expert_outputs.append(expert(expert_tokens)) + current_pos = next_pos - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - return final_out + # Combine the outputs from all experts + expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0) + + # Reorder the outputs to match the original token sequence + reordered_outputs = torch.empty_like(expert_outputs) + reordered_outputs[sorted_indices] = expert_outputs + + # Reshape and apply the expert weights + reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype) + moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs) + moe_output = moe_output.sum(dim=1).type(hidden_states.dtype) + return moe_output class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.q_head_dim = config.q_head_dim self.is_causal = True - - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( - config.kv_lora_rank, + self.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, - self.hidden_size, + config.hidden_size, bias=config.attention_bias, ) - self.rotary_emb = DeepseekV3RotaryEmbedding( - config=self.config, - ) + self.scaling = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, self.num_heads, -1) + batch_size, seq_length = input_shape - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) + k_pe = k_pe.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(hidden_shape).transpose(1, 2) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + cos, sin = position_embeddings + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: @@ -410,31 +328,72 @@ class DeepseekV3Attention(nn.Module): scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) return attn_output, attn_weights -class DeepseekV3DecoderLayer(LlamaDecoderLayer): +class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): pass