Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
hold code only for checkpoints configuration; remove redundant
This commit is contained in:
parent 51990b9436
commit f4f0ebd81c
3 changed files with 337 additions and 470 deletions
@@ -45,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
@@ -58,11 +56,9 @@ class DeepseekV3Config(PretrainedConfig):
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
            Number of shared experts.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Expert parallelism size for distributed training.
            Number of routed experts.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor of routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
@@ -75,28 +71,20 @@ class DeepseekV3Config(PretrainedConfig):
            Dimension of the value heads.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Dimension of the query/key heads that don't use rotary position embeddings.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                              \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
            Whether to compute the auxiliary loss for each individual sample.
        seq_aux (`bool`, *optional*, defaults to `True`):
            Whether to compute auxiliary loss at sequence level.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -145,6 +133,12 @@ class DeepseekV3Config(PretrainedConfig):
    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `DeepseekV3Model`
    base_model_tp_plan = {
        "layers.*.gate_proj": "colwise",
        "layers.*.up_proj": "colwise",
        "layers.*.down_proj": "rowwise",
    }

    def __init__(
        self,
@@ -153,28 +147,22 @@ class DeepseekV3Config(PretrainedConfig):
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="noaux_tc",
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func="sigmoid",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
@@ -197,27 +185,23 @@ class DeepseekV3Config(PretrainedConfig):
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.head_dim = qk_rope_head_dim
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux

        # for backward compatibility
        if num_key_value_heads is None:
|
@@ -7,9 +7,7 @@
import math
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
|
@@ -142,63 +140,41 @@ class MoEGate(nn.Module):
|
|||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.scoring_func = config.scoring_func
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||
if self.topk_method == "noaux_tc":
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
import torch.nn.init as init
|
||||
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
||||
if self.scoring_func == "sigmoid":
|
||||
scores = logits.sigmoid()
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
|
||||
### select top-k experts
|
||||
if self.topk_method == "noaux_tc":
|
||||
assert not self.training
|
||||
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(bsz * seq_len, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weight = scores.gather(1, topk_idx)
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
|
||||
|
||||
return topk_idx, topk_weight
|
||||
scores = logits.sigmoid()
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
) # [n, e]
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if self.norm_topk_prob:
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
|
||||
return topk_indices, topk_weights
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
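For readers who want the routing logic in isolation, the following is a hedged, self-contained re-expression of the group-limited top-k selection that the rewritten gate `forward` above performs. The function name, the toy sizes, and the stand-in `weight`/`e_score_correction_bias` tensors are illustrative and not part of the diff.

```python
import torch
import torch.nn.functional as F

def noaux_tc_route(hidden_states, weight, e_score_correction_bias,
                   n_group=8, topk_group=4, top_k=8,
                   norm_topk_prob=True, routed_scaling_factor=2.5):
    """Standalone sketch of the "noaux_tc" group-limited top-k routing used by MoEGate above."""
    n_routed_experts = weight.shape[0]
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])      # [n_tokens, hidden]
    scores = F.linear(hidden_states.float(), weight.float()).sigmoid()   # [n_tokens, n_experts]

    # The correction bias only influences *which* experts are chosen, not the returned weights.
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

    # Score each group by the sum of its two best experts, then keep the best groups.
    group_scores = scores_for_choice.view(-1, n_group, n_routed_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, n_group, n_routed_experts // n_group)
        .reshape(-1, n_routed_experts)
    )

    # Top-k experts restricted to the selected groups.
    masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_indices)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_indices, topk_weights * routed_scaling_factor

# Toy usage with assumed sizes (16 experts, hidden size 32):
x = torch.randn(2, 5, 32)
w = torch.randn(16, 32)
bias = torch.zeros(16)
idx, wts = noaux_tc_route(x, w, bias, n_group=4, topk_group=2, top_k=4)
print(idx.shape, wts.shape)  # torch.Size([10, 4]) torch.Size([10, 4])
```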
|
||||
|
|
@@ -209,116 +185,84 @@ class DeepseekV3MoE(nn.Module):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
if hasattr(config, "ep_size") and config.ep_size > 1:
|
||||
assert config.ep_size == dist.get_world_size()
|
||||
self.ep_size = config.ep_size
|
||||
self.experts_per_rank = config.n_routed_experts // config.ep_size
|
||||
self.ep_rank = dist.get_rank()
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
(
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
|
||||
else None
|
||||
)
|
||||
for i in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.ep_size = 1
|
||||
self.experts_per_rank = config.n_routed_experts
|
||||
self.ep_rank = 0
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for i in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
if not self.training:
|
||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
y = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
sorted_tokens_shape = sorted_tokens.shape
|
||||
if self.ep_size > 1:
|
||||
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
||||
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
|
||||
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
|
||||
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
|
||||
gathered_tokens = sorted_tokens.new_empty(
|
||||
tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
|
||||
)
|
||||
input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(output_splits)),
|
||||
list(sorted_tokens.split(input_split_sizes)),
|
||||
)
|
||||
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
|
||||
dim=0
|
||||
)
|
||||
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
|
||||
s = 0
|
||||
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
|
||||
gatherd_idxs[s : s + k] = i % self.experts_per_rank
|
||||
s += k
|
||||
gatherd_idxs = gatherd_idxs.argsort()
|
||||
sorted_tokens = gathered_tokens[gatherd_idxs]
|
||||
tokens_per_expert = tokens_per_expert_post_gather
|
||||
def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||
"""
|
||||
Perform inference using a Mixture of Experts (MoE) model.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input hidden states.
|
||||
topk_indices (torch.Tensor): Indices of the top-k experts for each token.
|
||||
topk_weights (torch.Tensor): Weights associated with the top-k experts.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output of the MoE model.
|
||||
"""
|
||||
num_experts = len(self.experts)
|
||||
batch_size, num_topk = topk_indices.shape
|
||||
|
||||
# Count the number of tokens assigned to each expert
|
||||
expert_counts = topk_indices.new_zeros((batch_size, num_experts))
|
||||
expert_counts.scatter_(1, topk_indices, 1)
|
||||
tokens_per_expert = expert_counts.sum(dim=0)
|
||||
|
||||
# Sort tokens by their assigned expert
|
||||
sorted_indices = topk_indices.view(-1).argsort()
|
||||
sorted_tokens = hidden_states[sorted_indices // num_topk]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
# Process tokens through their assigned experts
|
||||
expert_outputs = []
|
||||
current_pos = 0
|
||||
|
||||
for expert_idx, num_tokens in enumerate(tokens_per_expert):
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
if self.ep_size > 1:
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[gatherd_idxs] = outs
|
||||
gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(input_split_sizes)),
|
||||
list(new_x.split(output_splits)),
|
||||
)
|
||||
outs = gathered_tokens
|
||||
next_pos = current_pos + num_tokens
|
||||
expert = self.experts[expert_idx]
|
||||
expert_tokens = sorted_tokens[current_pos:next_pos]
|
||||
expert_outputs.append(expert(expert_tokens))
|
||||
current_pos = next_pos
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
# Combine the outputs from all experts
|
||||
expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0)
|
||||
|
||||
# Reorder the outputs to match the original token sequence
|
||||
reordered_outputs = torch.empty_like(expert_outputs)
|
||||
reordered_outputs[sorted_indices] = expert_outputs
|
||||
|
||||
# Reshape and apply the expert weights
|
||||
reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype)
|
||||
moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs)
|
||||
moe_output = moe_output.sum(dim=1).type(hidden_states.dtype)
|
||||
return moe_output
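A small sketch of the bookkeeping `moe_infer` above relies on, with hypothetical identity experts so the round trip is easy to verify: flat assignments are grouped per expert with `argsort`, processed in contiguous slices, scattered back to their original positions, and finally mixed with the top-k weights. The helper name and toy shapes are assumptions for illustration.

```python
import torch

def dispatch_and_combine(hidden_states, topk_indices, topk_weights, experts):
    """Sketch of the moe_infer dispatch/combine pattern, with arbitrary callables as experts."""
    num_tokens, top_k = topk_indices.shape
    sorted_indices = topk_indices.view(-1).argsort()          # group flat assignments by expert id
    sorted_tokens = hidden_states[sorted_indices // top_k]    # token row for each flat assignment
    counts = torch.bincount(topk_indices.view(-1), minlength=len(experts))

    outputs, start = [], 0
    for expert_idx, n in enumerate(counts.tolist()):
        if n:
            outputs.append(experts[expert_idx](sorted_tokens[start:start + n]))
        start += n
    outputs = torch.cat(outputs, dim=0)

    combined = torch.empty_like(outputs)
    combined[sorted_indices] = outputs                        # undo the argsort
    combined = combined.view(num_tokens, top_k, -1)
    return (topk_weights.unsqueeze(-1) * combined).sum(dim=1)

# Identity experts: the combined output must equal (sum of the top-k weights) * token.
x = torch.randn(6, 8)
idx = torch.randint(0, 4, (6, 2))
w = torch.rand(6, 2)
experts = [lambda t: t for _ in range(4)]
y = dispatch_and_combine(x, idx, w, experts)
assert torch.allclose(y, w.sum(dim=1, keepdim=True) * x, atol=1e-6)
```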
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
|
@@ -359,16 +303,14 @@ def eager_attention_forward(
|
|||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
|
|
@@ -379,15 +321,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_length, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_length, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
the shape [batch_size, seq_length, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
|
||||
b, h, s, d = q.shape
|
||||
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
b, h, s, d = k.shape
|
||||
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
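A quick check of the reshape/transpose that this version of `apply_rotary_pos_emb` performs before the rotation: as I read it, it moves interleaved rotary features into the half-split layout (even-indexed features first, odd-indexed second) that the `rotate_half` formulation expects. The shapes below are arbitrary illustration values.

```python
import torch

# Hedged check of the permutation used above.
b, h, s, d = 1, 1, 1, 8
x = torch.arange(d, dtype=torch.float32).view(b, h, s, d)         # [0, 1, 2, ..., 7]
permuted = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
print(permuted.flatten().tolist())                                # [0, 2, 4, 6, 1, 3, 5, 7]
```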
|
||||
|
|
@@ -396,111 +345,87 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||
class DeepseekV3Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.q_head_dim = config.q_head_dim
|
||||
|
||||
self.is_causal = True
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
config.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
config.kv_lora_rank,
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.rotary_emb = DeepseekV3RotaryEmbedding(
|
||||
config=self.config,
|
||||
)
|
||||
self.scaling = self.q_head_dim ** (-0.5)
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scaling = self.scaling * mscale * mscale
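A worked example of the scaling adjustment above, using assumed YaRN-style values (`factor=40`, `mscale_all_dim=1.0`) and the default `q_head_dim` of 128 + 64 = 192; the rope-scaling numbers are for illustration only and are not taken from this diff.

```python
import math

# Assumed illustration values, not from this commit.
q_head_dim = 192
rope_scaling = {"factor": 40, "mscale_all_dim": 1.0}

scaling = q_head_dim ** -0.5                                                              # ≈ 0.0722
mscale = 0.1 * rope_scaling["mscale_all_dim"] * math.log(rope_scaling["factor"]) + 1.0    # ≈ 1.369
scaling *= mscale * mscale                                                                # ≈ 0.1352
print(round(scaling, 4))
```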
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, self.num_heads, -1)
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k_pe = k_pe.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
|
||||
|
||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
query_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
key_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim)
|
||||
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
if self.q_head_dim != self.v_head_dim:
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
|
||||
|
||||
if past_key_value is not None:
|
||||
|
|
@@ -528,9 +453,12 @@ class DeepseekV3Attention(nn.Module):
|
|||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
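A hedged illustration of the flash-attention workaround in this hunk: since the query/key head dimension (192) differs from the value head dimension (128), the values are zero-padded up to `q_head_dim` before the kernel call and the padding is sliced off the attention output afterwards. Shapes only; no attention kernel is invoked here, and the sizes are the defaults from the configuration above.

```python
import torch
import torch.nn.functional as F

v = torch.randn(1, 4, 7, 128)          # [batch, heads, seq, v_head_dim]
padded = F.pad(v, [0, 192 - 128])      # pad last dim up to q_head_dim -> [1, 4, 7, 192]
restored = padded[:, :, :, :128]       # slice back to v_head_dim after attention
assert torch.equal(restored, v)
```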
|
||||
|
||||
|
||||
|
|
@@ -541,15 +469,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||
|
||||
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
if layer_idx >= config.first_k_dense_replace:
|
||||
self.mlp = DeepseekV3MoE(config)
|
||||
else:
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
|
||||
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
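The simplified condition above makes the dense/MoE layer layout easy to state; here is a tiny illustration, assuming the default `first_k_dense_replace=3` and an assumed 8-layer model:

```python
first_k_dense_replace, num_hidden_layers = 3, 8
layout = ["dense" if i < first_k_dense_replace else "moe" for i in range(num_hidden_layers)]
print(layout)  # ['dense', 'dense', 'dense', 'moe', 'moe', 'moe', 'moe', 'moe']
```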
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,9 +1,7 @@
import math
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
|
@@ -15,7 +13,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
|
@@ -23,6 +20,7 @@ from ..llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    eager_attention_forward,
    rotate_half,
)
from .configuration_deepseek_v3 import DeepseekV3Config
|
@@ -38,16 +36,14 @@ class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding):
|
|||
pass
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
|
|
@@ -58,15 +54,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_length, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_length, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
the shape [batch_size, seq_length, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
|
||||
b, h, s, d = q.shape
|
||||
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
b, h, s, d = k.shape
|
||||
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
|
@@ -96,63 +99,41 @@ class MoEGate(nn.Module):
|
|||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.scoring_func = config.scoring_func
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||
if self.topk_method == "noaux_tc":
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
import torch.nn.init as init
|
||||
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
||||
if self.scoring_func == "sigmoid":
|
||||
scores = logits.sigmoid()
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
|
||||
### select top-k experts
|
||||
if self.topk_method == "noaux_tc":
|
||||
assert not self.training
|
||||
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(bsz * seq_len, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weight = scores.gather(1, topk_idx)
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
|
||||
|
||||
return topk_idx, topk_weight
|
||||
scores = logits.sigmoid()
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
) # [n, e]
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if self.norm_topk_prob:
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
|
||||
return topk_indices, topk_weights
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
|
|
@@ -163,226 +144,163 @@ class DeepseekV3MoE(nn.Module):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
if hasattr(config, "ep_size") and config.ep_size > 1:
|
||||
assert config.ep_size == dist.get_world_size()
|
||||
self.ep_size = config.ep_size
|
||||
self.experts_per_rank = config.n_routed_experts // config.ep_size
|
||||
self.ep_rank = dist.get_rank()
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
(
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
|
||||
else None
|
||||
)
|
||||
for i in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.ep_size = 1
|
||||
self.experts_per_rank = config.n_routed_experts
|
||||
self.ep_rank = 0
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for i in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
if not self.training:
|
||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
y = self.moe_infer(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
sorted_tokens_shape = sorted_tokens.shape
|
||||
if self.ep_size > 1:
|
||||
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
||||
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
|
||||
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
|
||||
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
|
||||
gathered_tokens = sorted_tokens.new_empty(
|
||||
tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
|
||||
)
|
||||
input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(output_splits)),
|
||||
list(sorted_tokens.split(input_split_sizes)),
|
||||
)
|
||||
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
|
||||
dim=0
|
||||
)
|
||||
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
|
||||
s = 0
|
||||
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
|
||||
gatherd_idxs[s : s + k] = i % self.experts_per_rank
|
||||
s += k
|
||||
gatherd_idxs = gatherd_idxs.argsort()
|
||||
sorted_tokens = gathered_tokens[gatherd_idxs]
|
||||
tokens_per_expert = tokens_per_expert_post_gather
|
||||
def moe_infer(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||
"""
|
||||
Perform inference using a Mixture of Experts (MoE) model.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input hidden states.
|
||||
topk_indices (torch.Tensor): Indices of the top-k experts for each token.
|
||||
topk_weights (torch.Tensor): Weights associated with the top-k experts.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output of the MoE model.
|
||||
"""
|
||||
num_experts = len(self.experts)
|
||||
batch_size, num_topk = topk_indices.shape
|
||||
|
||||
# Count the number of tokens assigned to each expert
|
||||
expert_counts = topk_indices.new_zeros((batch_size, num_experts))
|
||||
expert_counts.scatter_(1, topk_indices, 1)
|
||||
tokens_per_expert = expert_counts.sum(dim=0)
|
||||
|
||||
# Sort tokens by their assigned expert
|
||||
sorted_indices = topk_indices.view(-1).argsort()
|
||||
sorted_tokens = hidden_states[sorted_indices // num_topk]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
# Process tokens through their assigned experts
|
||||
expert_outputs = []
|
||||
current_pos = 0
|
||||
|
||||
for expert_idx, num_tokens in enumerate(tokens_per_expert):
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
if self.ep_size > 1:
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[gatherd_idxs] = outs
|
||||
gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(input_split_sizes)),
|
||||
list(new_x.split(output_splits)),
|
||||
)
|
||||
outs = gathered_tokens
|
||||
next_pos = current_pos + num_tokens
|
||||
expert = self.experts[expert_idx]
|
||||
expert_tokens = sorted_tokens[current_pos:next_pos]
|
||||
expert_outputs.append(expert(expert_tokens))
|
||||
current_pos = next_pos
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
# Combine the outputs from all experts
|
||||
expert_outputs = torch.cat(expert_outputs, dim=0) if expert_outputs else sorted_tokens.new_empty(0)
|
||||
|
||||
# Reorder the outputs to match the original token sequence
|
||||
reordered_outputs = torch.empty_like(expert_outputs)
|
||||
reordered_outputs[sorted_indices] = expert_outputs
|
||||
|
||||
# Reshape and apply the expert weights
|
||||
reordered_outputs = reordered_outputs.view(batch_size, num_topk, -1).type(topk_weights.dtype)
|
||||
moe_output = torch.matmul(topk_weights.unsqueeze(1), reordered_outputs)
|
||||
moe_output = moe_output.sum(dim=1).type(hidden_states.dtype)
|
||||
return moe_output
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.q_head_dim = config.q_head_dim
|
||||
|
||||
self.is_causal = True
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
config.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
config.kv_lora_rank,
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.rotary_emb = DeepseekV3RotaryEmbedding(
|
||||
config=self.config,
|
||||
)
|
||||
self.scaling = self.q_head_dim ** (-0.5)
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, self.num_heads, -1)
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k_pe = k_pe.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
||||
cos, sin = position_embeddings
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
|
||||
|
||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
query_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
key_states = k_pe.new_empty(batch_size, self.num_heads, seq_length, self.q_head_dim)
|
||||
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
if self.q_head_dim != self.v_head_dim:
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
|
||||
|
||||
if past_key_value is not None:
|
||||
|
|
@@ -410,31 +328,72 @@ class DeepseekV3Attention(nn.Module):
|
|||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DeepseekV3DecoderLayer(LlamaDecoderLayer):
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
if layer_idx >= config.first_k_dense_replace:
|
||||
self.mlp = DeepseekV3MoE(config)
|
||||
else:
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
|
||||
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
|
||||
pass
|
||||
|
|
|
|||