Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
_output_embedding and _input_embeding
parent 893ef382c4
commit 13a195a7bb
2 changed files with 5 additions and 4 deletions
File 1 of 2:

@@ -1830,7 +1830,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
-            raise NotImplementedError
+            return getattr(self, self._embeding_layer)

     def set_input_embeddings(self, value: nn.Module):
         """
@@ -1843,7 +1843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if base_model is not self:
             base_model.set_input_embeddings(value)
         else:
-            raise NotImplementedError
+            setattr(self, self._embeding_layer, value)

     def get_output_embeddings(self) -> nn.Module:
         """
@@ -1852,7 +1852,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Returns:
             `nn.Module`: A torch module mapping hidden states to vocabulary.
         """
-        return None  # Overwrite for models with output embeddings
+        return getattr(self, self._output_embedding, None)

     def _init_weights(self, module):
         """
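One caveat on the new default: plain `getattr` cannot resolve a dotted attribute path, so a value such as "model.embed_tokens" (used in the second changed file below) would have to be walked one attribute at a time. The sketch below is not part of the commit; `resolve_submodule` is a hypothetical helper, and PyTorch's built-in `nn.Module.get_submodule` already provides the same dotted-path lookup.

import functools
import torch.nn as nn

def resolve_submodule(root: nn.Module, path: str) -> nn.Module:
    # Hypothetical helper: walk a dotted path such as "model.embed_tokens"
    # one attribute at a time, since getattr() does not split on ".".
    return functools.reduce(getattr, path.split("."), root)

# Toy module tree mirroring the "model.embed_tokens" layout.
class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)

class Outer(nn.Module):
    _embeding_layer = "model.embed_tokens"  # spelling follows the commit

    def __init__(self):
        super().__init__()
        self.model = Inner()

    def get_input_embeddings(self) -> nn.Module:
        # Same intent as the committed getattr(self, self._embeding_layer),
        # but robust to dotted paths.
        return resolve_submodule(self, self._embeding_layer)

m = Outer()
assert isinstance(m.get_input_embeddings(), nn.Embedding)
assert m.get_submodule("model.embed_tokens") is m.model.embed_tokens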
File 2 of 2:

@@ -1,6 +1,7 @@
 import torch
 from transformers import PreTrainedModel

-class AutoForCausalLM(LlamaPreTrainedModel, GenerationMixin):
+class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
+    _embeding_layer = "model.embed_tokens"
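Taken together, the three new defaults let a subclass expose its embeddings purely through the two class attributes. A minimal self-contained sketch of that behavior, with illustrative attribute names (`tok_emb` and `lm_head` are assumptions, not from the commit):

import torch.nn as nn

class TinyLM(nn.Module):
    # Attribute names are illustrative, not from the commit.
    _embeding_layer = "tok_emb"
    _output_embedding = "lm_head"

    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)

    # The three default implementations from the commit, inlined:
    def get_input_embeddings(self) -> nn.Module:
        return getattr(self, self._embeding_layer)

    def set_input_embeddings(self, value: nn.Module):
        setattr(self, self._embeding_layer, value)

    def get_output_embeddings(self) -> nn.Module:
        # getattr with a default: models without output embeddings return None.
        return getattr(self, self._output_embedding, None)

lm = TinyLM()
lm.set_input_embeddings(nn.Embedding(20, 4))  # swap in a larger vocabulary
assert lm.get_input_embeddings().num_embeddings == 20
assert lm.get_output_embeddings() is lm.lm_head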