default init weights

Arthur Zucker 2024-12-12 10:18:38 +01:00
parent 53450ac365
commit 2016bc47d0
2 changed files with 9 additions and 12 deletions

src/transformers/modeling_utils.py

@@ -1887,7 +1887,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         using `from_pretrained`. Any attempt to initialize outside of this function
         will be useless, as the torch.nn.init functions are all replaced with skips.
         """
-        pass
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def _initialize_weights(self, module):
         """

src/transformers/models/llama/modeling_llama.py

@@ -334,17 +334,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     gradient_checkpointing = False
 
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
 class LlamaModel(LlamaPreTrainedModel):
     _input_embedding = "embed_tokens"
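
With the default now living on the base class, the Llama override above is an exact duplicate and is deleted; any model whose init is "normal Linear/Embedding weights, zero biases, zero padding row" inherits it for free, which is the deduplication this commit is after. A hedged usage check (assuming a transformers build containing this commit; routing construction through post_init(), which applies _init_weights to every submodule, is standard base-class behavior):

from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=1, num_attention_heads=4)
# LlamaModel.__init__ ends in post_init(), which now reaches the
# inherited default _init_weights rather than a per-model copy.
model = LlamaModel(config)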