Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
default init weights

commit 2016bc47d0 (parent 53450ac365)
2 changed files with 9 additions and 12 deletions
@@ -1887,7 +1887,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin)
         using `from_pretrained`. Any attempt to initialize outside of this function
         will be useless as the torch.nn.init function are all replaced with skip.
         """
-        pass
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def _initialize_weights(self, module):
         """
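For context, here is a minimal, self-contained sketch (outside transformers) of what this new default does when fanned out over a module tree with nn.Module.apply, which is how _init_weights is typically driven; the STD constant and the toy model are illustrative assumptions, not part of the commit:

    import torch
    from torch import nn

    # Illustrative stand-in for config.initializer_range; 0.02 is an
    # assumption (a common default in transformers configs).
    STD = 0.02

    def default_init_weights(module):
        # Same logic as the default _init_weights added above.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=STD)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=STD)
            if module.padding_idx is not None:
                # Re-zero the padding row after the normal_ fill.
                module.weight.data[module.padding_idx].zero_()

    model = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 2))
    model.apply(default_init_weights)  # .apply() visits every submodule
    assert torch.count_nonzero(model[0].weight[0]) == 0  # padding row zeroed
    assert torch.count_nonzero(model[1].bias) == 0       # Linear bias zeroed

Note the docstring's caveat: inside from_pretrained the torch.nn.init functions are temporarily replaced with no-ops, so this method is the only place where initialization reliably takes effect.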
@@ -334,17 +334,6 @@ class LlamaPreTrainedModel(PreTrainedModel)
     _supports_static_cache = True
     gradient_checkpointing = False
 
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
 
 class LlamaModel(LlamaPreTrainedModel):
     _input_embedding = "embed_tokens"
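The override removed here was identical, line for line, to the new default added to PreTrainedModel, so the deletion is behavior-preserving: LlamaPreTrainedModel now picks up _init_weights through ordinary inheritance. A toy sketch of that resolution, with hypothetical class names:

    from torch import nn

    class Base:
        def _init_weights(self, module):
            # Stands in for the shared default on the base model class.
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)

    class Child(Base):
        # No override anymore: Base._init_weights is inherited unchanged.
        pass

    # The subclass resolves to the exact same function object.
    assert Child._init_weights is Base._init_weights
    Child()._init_weights(nn.Linear(2, 2))  # runs the inherited default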