fix some stuff

Arthur Zucker 2024-12-12 09:22:04 +01:00
parent 6028e85990
commit 725d00caf4
3 changed files with 7 additions and 4 deletions

@@ -1354,6 +1354,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
     # `config.base_model_tp_plan` during `post_init`.
     _tp_plan = None
 
+    _output_embedding = None
+    _input_embedding = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1832,7 +1835,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
-            return getattr(self, self._embeding_layer)
+            return getattr(self, self._input_embedding)
 
     def set_input_embeddings(self, value: nn.Module):
         """
@@ -1845,7 +1848,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if base_model is not self:
             base_model.set_input_embeddings(value)
         else:
-            raise setattr(self, self._embeding_layer, value)
+            setattr(self, self._input_embedding, value)
 
     def get_output_embeddings(self) -> nn.Module:
         """

@@ -19,7 +19,7 @@ from ..auto import AutoModel
 class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
-    _embeding_layer = "model.embed_tokens"
+    _input_embedding = "model.embed_tokens"
     _output_embedding = "lm_head"
     _no_split_modules = []
     _supports_cache_class = True

@@ -347,7 +347,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
 class LlamaModel(LlamaPreTrainedModel):
-    _embedding_layer = "embed_tokens"
+    _input_embedding = "embed_tokens"
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
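Since a plain `getattr` cannot resolve a dotted name such as "model.embed_tokens", the head model presumably relies on the existing base-model delegation shown in the first hunk (`base_model is not self`). A rough, self-contained sketch of that path, where `base_model_prefix` and the `Inner*` classes are assumptions made for illustration:

import torch.nn as nn

class InnerModel(nn.Module):
    # Hypothetical stand-in for LlamaModel: a plain attribute name, no dots.
    _input_embedding = "embed_tokens"

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)

    def get_input_embeddings(self) -> nn.Module:
        return getattr(self, self._input_embedding)

class InnerForCausalLM(nn.Module):
    # Hypothetical stand-in for the causal-LM wrapper: the lookup is delegated
    # to the base model found under `base_model_prefix`, mirroring the
    # `base_model is not self` branch in the diff above.
    base_model_prefix = "model"
    _input_embedding = "model.embed_tokens"
    _output_embedding = "lm_head"

    def __init__(self):
        super().__init__()
        self.model = InnerModel()
        self.lm_head = nn.Linear(8, 32, bias=False)

    def get_input_embeddings(self) -> nn.Module:
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        return getattr(self, self._input_embedding)

lm = InnerForCausalLM()
assert lm.get_input_embeddings() is lm.model.embed_tokens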