diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
index 5943aa126..c0b412dde 100644
--- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
@@ -135,12 +135,12 @@ class DeepseekV3Config(PretrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `DeepseekV3Model`
     base_model_tp_plan = {
-        "layers.*.gate_proj": "colwise",
-        "layers.*.up_proj": "colwise",
-        "layers.*.down_proj": "rowwise",
         "layers.*.self_attn.q_b_proj": "colwise",
         "layers.*.self_attn.kv_b_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
     }
 
     def __init__(
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index ef6be2f6a..3757a26ca 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -377,7 +377,6 @@ class DeepseekV3Attention(nn.Module):
         k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
 
         cos, sin = position_embeddings
-        q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
         q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
         k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
 
@@ -420,13 +419,6 @@
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
 
-    def reshape_for_rope(self, q, k):
-        b, h, s, d = q.shape
-        q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
-        b, h, s, d = k.shape
-        k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
-        return q, k
-
 
 class DeepseekV3DecoderLayer(nn.Module):
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
@@ -629,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         config: DeepseekV3Config
     """
 
-    def __init__(self, config: DeepseekV3Config):
+    def __init__(self, config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -641,6 +633,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+        self._register_load_state_dict_pre_hook(self.load_pre_hook)
+        self._register_state_dict_hook(self.load_hook)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -879,6 +873,52 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
 
         return causal_mask
 
+    def load_pre_hook(self, state_dict, prefix, *args):
+        """
+        Weights have to be permuted for the sin/cos rope formulation used here. We cannot ship them
+        already permuted in the checkpoint, because every other framework expects the original
+        (interleaved) layout of the reference `Llama` implementation, so we permute them at load time.
+        """
+
+        def permute_for_rope(input_tensor):
+            """
+            When you go from the complex RoPE formulation to the sin/cos one, the query and key
+            weights need to be permuted (permuting them here avoids doing it on the fly at runtime).
+            """
+            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
+            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
+            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
+            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
+            return input_tensor
+
+        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
+            weight = state_dict[key]
+            weight = weight.view(num_heads, head_dim, -1)
+            weight_rot = weight[:, -rope_dim:]
+            weight_rot = permute_for_rope(weight_rot)
+            weight[:, -rope_dim:] = weight_rot
+            weight = weight.view(-1, weight.shape[-1])
+            state_dict[key] = weight
+
+        for k in state_dict:
+            if "q_b_proj." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=self.config.num_attention_heads,
+                    head_dim=self.config.qk_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+            if "kv_a_proj_with_mqa." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=1,
+                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+
+    def load_hook(self, module, state_dict, prefix, *args):
+        self.load_pre_hook(state_dict, prefix, *args)
+
 
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py
index 55d1ac5c6..e41ded87f 100644
--- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py
@@ -225,7 +225,6 @@ class DeepseekV3Attention(nn.Module):
         k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
 
         cos, sin = position_embeddings
-        q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
         q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
         k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
 
@@ -268,13 +267,6 @@
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
 
-    def reshape_for_rope(self, q, k):
-        b, h, s, d = q.shape
-        q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
-        b, h, s, d = k.shape
-        k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
-        return q, k
-
 
 class DeepseekV3DecoderLayer(nn.Module):
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
@@ -348,7 +340,57 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
 
 
 class DeepseekV3Model(LlamaModel):
-    pass
+    def __init__(self, config):
+        super().__init__(config)
+        self._register_load_state_dict_pre_hook(self.load_pre_hook)
+        self._register_state_dict_hook(self.load_hook)
+        self.post_init()
+
+    def load_pre_hook(self, state_dict, prefix, *args):
+        """
+        Weights have to be permuted for the sin/cos rope formulation used here. We cannot ship them
+        already permuted in the checkpoint, because every other framework expects the original
+        (interleaved) layout of the reference `Llama` implementation, so we permute them at load time.
+        """
+
+        def permute_for_rope(input_tensor):
+            """
+            When you go from the complex RoPE formulation to the sin/cos one, the query and key
+            weights need to be permuted (permuting them here avoids doing it on the fly at runtime).
+            """
+            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
+            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
+            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
+            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
+            return input_tensor
+
+        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
+            weight = state_dict[key]
+            weight = weight.view(num_heads, head_dim, -1)
+            weight_rot = weight[:, -rope_dim:]
+            weight_rot = permute_for_rope(weight_rot)
+            weight[:, -rope_dim:] = weight_rot
+            weight = weight.view(-1, weight.shape[-1])
+            state_dict[key] = weight
+
+        for k in state_dict:
+            if "q_b_proj." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=self.config.num_attention_heads,
+                    head_dim=self.config.qk_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+            if "kv_a_proj_with_mqa." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=1,
+                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+
+    def load_hook(self, module, state_dict, prefix, *args):
+        self.load_pre_hook(state_dict, prefix, *args)
 
 
 class DeepseekV3ForCausalLM(LlamaForCausalLM):
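
Not part of the diff itself, but a quick standalone sanity check of the idea behind it (the helper name `interleaved_to_half_split` is made up here, not from the PR): reordering the rows of a projection weight reorders its output features in exactly the same way, so the interleaved-to-half-split permutation needed by the sin/cos RoPE formulation can be applied to the checkpoint weights once at load time instead of being applied to the activations on every forward pass, which is what the removed `reshape_for_rope` did.

```python
import torch

torch.manual_seed(0)
rope_dim, in_features, batch = 8, 16, 4


def interleaved_to_half_split(t):
    """[x0, y0, x1, y1, ...] -> [x0, x1, ..., y0, y1, ...] along dim 0 (hypothetical helper)."""
    d = t.shape[0]
    return t.reshape(d // 2, 2, -1).transpose(0, 1).reshape(d, -1)


w = torch.randn(rope_dim, in_features)  # weight rows == output features of the projection
x = torch.randn(batch, in_features)

# Old path: project with the original weight, then permute the activations on the fly.
out_runtime = interleaved_to_half_split((x @ w.T).T)

# New path: permute the weight rows once (what the load-time hook does), then project.
out_loadtime = (x @ interleaved_to_half_split(w).T).T

assert torch.allclose(out_runtime, out_loadtime)
```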