_output_embedding and _input_embeding

Arthur Zucker 2024-12-11 13:53:35 +01:00
parent 893ef382c4
commit 13a195a7bb
2 changed files with 5 additions and 4 deletions


@@ -1830,7 +1830,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
-            raise NotImplementedError
+            return getattr(self, self._embeding_layer)
 
     def set_input_embeddings(self, value: nn.Module):
         """
@@ -1843,7 +1843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if base_model is not self:
             base_model.set_input_embeddings(value)
         else:
-            raise NotImplementedError
+            setattr(self, self._embeding_layer, value)
 
     def get_output_embeddings(self) -> nn.Module:
         """
@@ -1852,7 +1852,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Returns:
             `nn.Module`: A torch module mapping hidden states to vocabulary.
         """
-        return None  # Overwrite for models with output embeddings
+        return getattr(self, self._output_embedding, None)
 
     def _init_weights(self, module):
         """


@@ -1,6 +1,7 @@
 import torch
 from transformers import PreTrainedModel
 
-class AutoForCausalLM(LlamaPreTrainedModel, GenerationMixin):
+class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
+    _embeding_layer = "model.embed_tokens"
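Note that the pointer set here is a dotted path ("model.embed_tokens"), which a single getattr call cannot follow. The sketch below walks the path segment by segment; this is an assumption about the intended resolution, not what the committed getattr call does, and Backbone / TinyForCausalLM are hypothetical:

import torch.nn as nn

class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)

class TinyForCausalLM(nn.Module):
    _embeding_layer = "model.embed_tokens"  # same dotted pointer as the diff

    def __init__(self):
        super().__init__()
        self.model = Backbone()

    def get_input_embeddings(self) -> nn.Module:
        # getattr(self, "model.embed_tokens") would raise AttributeError,
        # so resolve the dotted path one attribute at a time.
        module = self
        for name in self._embeding_layer.split("."):
            module = getattr(module, name)
        return module

lm = TinyForCausalLM()
assert lm.get_input_embeddings() is lm.model.embed_tokens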