register pre_hook and hook both
This commit is contained in: parent ea3c922554, commit c8132687df
3 changed files with 103 additions and 21 deletions
@@ -135,12 +135,12 @@ class DeepseekV3Config(PretrainedConfig):
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `DeepseekV3Model`
    base_model_tp_plan = {
        "layers.*.gate_proj": "colwise",
        "layers.*.up_proj": "colwise",
        "layers.*.down_proj": "rowwise",
        "layers.*.self_attn.q_b_proj": "colwise",
        "layers.*.self_attn.kv_b_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.gate_proj": "colwise",
        "layers.*.up_proj": "colwise",
        "layers.*.down_proj": "rowwise",
    }

    def __init__(
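Note that the gate_proj/up_proj/down_proj entries appear twice in the hunk above: the add/remove markers were lost when this page was captured, and with 12 lines on both sides of the hunk one of the two trios is the old placement and the other the new one. For context, `base_model_tp_plan` maps glob-style module-name patterns to a tensor-parallel sharding style ("colwise" shards output features, "rowwise" shards input features). A minimal sketch of what such a lookup amounts to, using fnmatch as a stand-in for the library's own pattern matcher (illustrative only):

from fnmatch import fnmatch

# Illustrative only: fnmatch stands in for transformers' internal matching.
plan = {
    "layers.*.self_attn.q_b_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.gate_proj": "colwise",
    "layers.*.down_proj": "rowwise",
}

def sharding_style(module_name):
    # First pattern that matches the fully qualified module name wins.
    for pattern, style in plan.items():
        if fnmatch(module_name, pattern):
            return style
    return None  # module is not sharded

print(sharding_style("layers.3.self_attn.o_proj"))  # rowwise
print(sharding_style("layers.3.gate_proj"))         # colwise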
@@ -377,7 +377,6 @@ class DeepseekV3Attention(nn.Module):
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
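This hunk shrinks from 7 lines to 6, and the dropped line is almost certainly the `self.reshape_for_rope(...)` call, since the method itself is deleted in the next hunk. The surviving `k_rot.expand(*k_pass.shape[:-1], -1)` broadcasts the single rotary key head across every attention head without copying memory; a small shape sketch with made-up dimensions:

import torch

# Illustrative shapes only, not the model's real dimensions.
batch, heads, seq, rope_dim, nope_dim = 2, 4, 5, 8, 16
k_rot = torch.randn(batch, 1, seq, rope_dim)        # one shared rotary head
k_pass = torch.randn(batch, heads, seq, nope_dim)   # per-head, non-rotary part
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)        # returns a view, no copy
print(k_rot.shape)                                  # torch.Size([2, 4, 5, 8])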
@@ -420,13 +419,6 @@ class DeepseekV3Attention(nn.Module):
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def reshape_for_rope(self, q, k):
        b, h, s, d = q.shape
        q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
        b, h, s, d = k.shape
        k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
        return q, k


class DeepseekV3DecoderLayer(nn.Module):
    def __init__(self, config: DeepseekV3Config, layer_idx: int):
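This hunk goes from 13 lines to 6: `reshape_for_rope` is removed. It used to convert q and k from the interleaved RoPE layout (x0, y0, x1, y1, ...) to the split layout (x0, x1, ..., y0, y1, ...) expected by the sin/cos `apply_rotary_pos_emb` on every forward pass; the commit moves that permutation into the new load-time hook instead (see the hunks further down). A toy demonstration of the permutation the deleted `view/transpose/reshape` chain performs:

import torch

# Illustrative only: one head with head_dim = 8.
d = 8
x = torch.arange(d)
deinterleaved = x.view(d // 2, 2).transpose(-1, -2).reshape(d)
print(x.tolist())              # [0, 1, 2, 3, 4, 5, 6, 7]  (interleaved pairs)
print(deinterleaved.tolist())  # [0, 2, 4, 6, 1, 3, 5, 7]  (split halves)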
@@ -629,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
    config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -641,6 +633,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        self._register_state_dict_hook(self.load_hook)

        # Initialize weights and apply final processing
        self.post_init()
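These two `_register_*` calls are what the commit title refers to: `load_pre_hook` runs while `load_state_dict()` processes an incoming checkpoint, and `load_hook`, registered as a state-dict hook (it simply delegates to `load_pre_hook`, see the next hunk), runs after `state_dict()` has assembled the output dict. A minimal, self-contained sketch of the two private PyTorch hooks on a toy module (the `Toy` class and its prints are illustrative, not part of the diff):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # called by load_state_dict() before tensors are copied into the module
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        # called by state_dict() after the output dict has been assembled
        self._register_state_dict_hook(self.load_hook)

    def load_pre_hook(self, state_dict, prefix, *args):
        print("about to load:", list(state_dict))

    def load_hook(self, module, state_dict, prefix, *args):
        print("just saved:", list(state_dict))

m = Toy()
sd = m.state_dict()    # triggers load_hook
m.load_state_dict(sd)  # triggers load_pre_hook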
@@ -879,6 +873,52 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):

        return causal_mask

    def load_pre_hook(self, state_dict, prefix, *args):
        """
        Weights have to be permuted for the correct RoPE formulation. We can't
        ship them pre-permuted because every other framework uses the original
        `Llama` layout.
        """

        def permute_for_rope(input_tensor):
            """
            When you go from the complex RoPE formulation to the sin/cos one, the
            query and key weights have to be permuted (to avoid doing it on the fly).
            """
            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
            return input_tensor

        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
            weight = state_dict[key]
            weight = weight.view(num_heads, head_dim, -1)
            weight_rot = weight[:, -rope_dim:]
            weight_rot = permute_for_rope(weight_rot)
            weight[:, -rope_dim:] = weight_rot
            weight = weight.view(-1, weight.shape[-1])
            state_dict[key] = weight

        for k in state_dict:
            if "q_b_proj." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=self.config.num_attention_heads,
                    head_dim=self.config.qk_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )
            if "kv_a_proj_with_mqa." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=1,
                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )

    def load_hook(self, module, state_dict, prefix, *args):
        self.load_pre_hook(state_dict, prefix, *args)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
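The pre-hook rewrites the projection weights in the incoming state dict rather than the activations at runtime. That is valid because permuting the rows of a linear layer's weight permutes its outputs in exactly the same way, so a one-time permutation at load time replaces the per-forward `reshape_for_rope` deleted above. A quick check of that equivalence with toy shapes (illustrative only, not part of the diff):

import torch

torch.manual_seed(0)
d_in, d_out = 6, 8
W = torch.randn(d_out, d_in)   # stand-in for a q/k projection weight
x = torch.randn(d_in)          # stand-in for the hidden state

# The same deinterleave permutation used by the hook / the old reshape_for_rope.
perm = torch.arange(d_out).view(d_out // 2, 2).transpose(-1, -2).reshape(d_out)

permute_output = (W @ x)[perm]  # permute activations after the matmul
permute_weight = W[perm] @ x    # permute weight rows once, up front
assert torch.allclose(permute_output, permute_weight)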
@@ -225,7 +225,6 @@ class DeepseekV3Attention(nn.Module):
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
@@ -268,13 +267,6 @@ class DeepseekV3Attention(nn.Module):
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def reshape_for_rope(self, q, k):
        b, h, s, d = q.shape
        q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
        b, h, s, d = k.shape
        k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
        return q, k


class DeepseekV3DecoderLayer(nn.Module):
    def __init__(self, config: DeepseekV3Config, layer_idx: int):
@@ -348,7 +340,57 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):


class DeepseekV3Model(LlamaModel):
    pass
    def __init__(self, config):
        super().__init__(config)
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        self._register_state_dict_hook(self.load_hook)
        self.post_init()

    def load_pre_hook(self, state_dict, prefix, *args):
        """
        Weights have to be permuted for the correct RoPE formulation. We can't
        ship them pre-permuted because every other framework uses the original
        `Llama` layout.
        """

        def permute_for_rope(input_tensor):
            """
            When you go from the complex RoPE formulation to the sin/cos one, the
            query and key weights have to be permuted (to avoid doing it on the fly).
            """
            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
            return input_tensor

        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
            weight = state_dict[key]
            weight = weight.view(num_heads, head_dim, -1)
            weight_rot = weight[:, -rope_dim:]
            weight_rot = permute_for_rope(weight_rot)
            weight[:, -rope_dim:] = weight_rot
            weight = weight.view(-1, weight.shape[-1])
            state_dict[key] = weight

        for k in state_dict:
            if "q_b_proj." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=self.config.num_attention_heads,
                    head_dim=self.config.qk_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )
            if "kv_a_proj_with_mqa." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=1,
                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )

    def load_hook(self, module, state_dict, prefix, *args):
        self.load_pre_hook(state_dict, prefix, *args)


class DeepseekV3ForCausalLM(LlamaForCausalLM):