Register both the pre_hook and the state_dict hook

ryan u 2025-02-03 19:05:15 +09:00
parent ea3c922554
commit c8132687df
3 changed files with 103 additions and 21 deletions

@@ -135,12 +135,12 @@ class DeepseekV3Config(PretrainedConfig):
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `DeepseekV3Model`
base_model_tp_plan = {
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
"layers.*.self_attn.q_b_proj": "colwise",
"layers.*.self_attn.kv_b_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
}
def __init__(
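The plan entries above map module name patterns to sharding strategies: "colwise" splits a linear layer's output features across tensor-parallel ranks, while "rowwise" splits its input features so the per-rank partial outputs are summed. A minimal single-process sketch of the arithmetic behind the two strategies (toy sizes, a hypothetical two-rank split, no actual distributed machinery):

import torch
import torch.nn as nn

# Toy illustration of the "colwise" / "rowwise" strategies over two hypothetical
# ranks; the real sharding is done by the tensor-parallel runtime, this only
# shows the arithmetic it relies on.
torch.manual_seed(0)
full = nn.Linear(8, 8, bias=False)
x = torch.randn(2, 8)

# colwise: each rank holds half of the weight's output rows; the per-rank
# outputs are concatenated along the feature dimension.
w0, w1 = full.weight.chunk(2, dim=0)
colwise_out = torch.cat([x @ w0.T, x @ w1.T], dim=-1)

# rowwise: each rank holds half of the weight's input columns; the per-rank
# partial outputs are summed (an all-reduce in the distributed case).
v0, v1 = full.weight.chunk(2, dim=1)
x0, x1 = x.chunk(2, dim=-1)
rowwise_out = x0 @ v0.T + x1 @ v1.T

assert torch.allclose(colwise_out, full(x), atol=1e-6)
assert torch.allclose(rowwise_out, full(x), atol=1e-6)

Pairing colwise projections with the rowwise projection that consumes their output (gate/up into down, q/kv into o_proj) keeps the intermediate activations sharded and defers communication to a single reduction at the end.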

@@ -377,7 +377,6 @@ class DeepseekV3Attention(nn.Module):
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
cos, sin = position_embeddings
q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
@@ -420,13 +419,6 @@ class DeepseekV3Attention(nn.Module):
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
def reshape_for_rope(self, q, k):
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
return q, k
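The removed reshape_for_rope de-interleaved the rope dimensions of q and k on every forward pass so that checkpoints stored in the interleaved (complex) rope layout line up with apply_rotary_pos_emb's rotate-half convention. Because a linear projection's output is linear in the weight rows, the same reordering can be applied once to the weights at load time instead, which is what the new hooks below do. A small sketch of the permutation and of that equivalence (toy shapes, not the real head dimensions):

import torch

# What the removed reshape did to the last dimension: de-interleave
# [x0, x1, x2, x3, ...] into [x0, x2, ..., x1, x3, ...], the layout the
# rotate-half rope implementation expects.
d = 6
x = torch.arange(d)
deinterleaved = x.view(d // 2, 2).transpose(-1, -2).reshape(d)
print(deinterleaved)  # tensor([0, 2, 4, 1, 3, 5])

# Reordering the projection output on every forward is equivalent to
# reordering the projection weight rows once, since the output is linear
# in those rows.
w = torch.randn(d, 4)   # toy projection weight (out_features, in_features)
h = torch.randn(4)
out_then_permute = (w @ h).view(d // 2, 2).transpose(-1, -2).reshape(d)
permute_then_out = w.view(d // 2, 2, 4).transpose(0, 1).reshape(d, 4) @ h
assert torch.allclose(out_then_permute, permute_then_out, atol=1e-6)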
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: DeepseekV3Config, layer_idx: int):
@@ -629,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
config: DeepseekV3Config
"""
def __init__(self, config: DeepseekV3Config):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@@ -641,6 +633,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self._register_load_state_dict_pre_hook(self.load_pre_hook)
self._register_state_dict_hook(self.load_hook)
# Initialize weights and apply final processing
self.post_init()
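The two registrations above use PyTorch's (private) module hook APIs: a load_state_dict pre-hook runs inside load_state_dict before tensors are copied into the module, so it can rewrite the incoming state dict in place; a state_dict hook runs after state_dict() is assembled and can rewrite the outgoing dict. A minimal sketch of the mechanics on a toy module (the doubling/halving transform is invented only to show when each hook fires, it is not what the DeepseekV3 hooks do):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)
        # fires inside load_state_dict(), before the tensors are copied in
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        # fires after state_dict() has been assembled
        self._register_state_dict_hook(self.save_hook)

    def load_pre_hook(self, state_dict, prefix, *args):
        state_dict[prefix + "proj.weight"] = state_dict[prefix + "proj.weight"] * 2

    def save_hook(self, module, state_dict, prefix, *args):
        state_dict[prefix + "proj.weight"] = state_dict[prefix + "proj.weight"] / 2

m = Toy()
m.load_state_dict({"proj.weight": torch.eye(4)})                 # pre-hook doubles on the way in
assert torch.equal(m.proj.weight, 2 * torch.eye(4))
assert torch.equal(m.state_dict()["proj.weight"], torch.eye(4))  # hook halves on the way out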
@@ -879,6 +873,52 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
return causal_mask
def load_pre_hook(self, state_dict, prefix, *args):
"""
Weights have to be permuted for the sin/cos rope formulation used here. We cannot ship the
checkpoints pre-permuted, since every other framework relies on the original `Llama` rope
function and layout, so the permutation is applied when the state dict is loaded.
"""
def permute_for_rope(input_tensor):
"""
When you go from the complex RoPE formulation to the sin/cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
return input_tensor
def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
weight = state_dict[key]
weight = weight.view(num_heads, head_dim, -1)
weight_rot = weight[:, -rope_dim:]
weight_rot = permute_for_rope(weight_rot)
weight[:, -rope_dim:] = weight_rot
weight = weight.view(-1, weight.shape[-1])
state_dict[key] = weight
for k in state_dict:
if "q_b_proj." in k:
permute_layer_for_rope(
k,
num_heads=self.config.num_attention_heads,
head_dim=self.config.qk_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
if "kv_a_proj_with_mqa." in k:
permute_layer_for_rope(
k,
num_heads=1,
head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
def load_hook(self, module, state_dict, prefix, *args):
self.load_pre_hook(state_dict, prefix, *args)
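To make the slicing in permute_layer_for_rope concrete: within each head's block of the fused projection weight, only the last rope_dim rows are reordered; the preceding nope/latent rows are left untouched. A toy check of that behaviour with made-up sizes, using the same reshape/transpose steps as the hook:

import torch

num_heads, head_dim, rope_dim, in_dim = 2, 8, 4, 3
weight = torch.randn(num_heads * head_dim, in_dim)

# apply the hook's permutation to a copy
w = weight.view(num_heads, head_dim, -1).clone()
rot = w[:, -rope_dim:]
rot = rot.reshape(num_heads, rope_dim // 2, 2, in_dim).transpose(1, 2).reshape(num_heads, rope_dim, in_dim)
w[:, -rope_dim:] = rot

# the non-rope rows of every head are unchanged...
orig = weight.view(num_heads, head_dim, -1)
assert torch.equal(w[:, :-rope_dim], orig[:, :-rope_dim])
# ...and the rope rows are de-interleaved into half-split order
deinterleave = torch.tensor([0, 2, 1, 3])
assert torch.equal(w[:, -rope_dim:], orig[:, -rope_dim:][:, deinterleave])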
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

@@ -225,7 +225,6 @@ class DeepseekV3Attention(nn.Module):
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
cos, sin = position_embeddings
q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
@@ -268,13 +267,6 @@ class DeepseekV3Attention(nn.Module):
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
def reshape_for_rope(self, q, k):
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
return q, k
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: DeepseekV3Config, layer_idx: int):
@@ -348,7 +340,57 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
class DeepseekV3Model(LlamaModel):
pass
def __init__(self, config):
super().__init__(config)
self._register_load_state_dict_pre_hook(self.load_pre_hook)
self._register_state_dict_hook(self.load_hook)
self.post_init()
def load_pre_hook(self, state_dict, prefix, *args):
"""
Weights have to be permuted for the sin/cos rope formulation used here. We cannot ship the
checkpoints pre-permuted, since every other framework relies on the original `Llama` rope
function and layout, so the permutation is applied when the state dict is loaded.
"""
def permute_for_rope(input_tensor):
"""
When you go from the complex RoPE formulation to the sin/cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
return input_tensor
def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
weight = state_dict[key]
weight = weight.view(num_heads, head_dim, -1)
weight_rot = weight[:, -rope_dim:]
weight_rot = permute_for_rope(weight_rot)
weight[:, -rope_dim:] = weight_rot
weight = weight.view(-1, weight.shape[-1])
state_dict[key] = weight
for k in state_dict:
if "q_b_proj." in k:
permute_layer_for_rope(
k,
num_heads=self.config.num_attention_heads,
head_dim=self.config.qk_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
if "kv_a_proj_with_mqa." in k:
permute_layer_for_rope(
k,
num_heads=1,
head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
def load_hook(self, module, state_dict, prefix, *args):
self.load_pre_hook(state_dict, prefix, *args)
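Because the hooks are registered in __init__, they fire automatically whenever a checkpoint is loaded, for example via from_pretrained; nothing else has to invoke them. A hedged usage sketch (the checkpoint id is only illustrative):

from transformers import AutoModelForCausalLM

# Illustrative only: loading a DeepseekV3 checkpoint stored in the original
# layout triggers load_pre_hook inside load_state_dict, so the rope rows of
# q_b_proj / kv_a_proj_with_mqa are permuted on the way in.
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")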
class DeepseekV3ForCausalLM(LlamaForCausalLM):