Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
fix some stuff
This commit is contained in:
parent
6028e85990
commit
725d00caf4
3 changed files with 7 additions and 4 deletions
@@ -1354,6 +1354,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
     # `config.base_model_tp_plan` during `post_init`.
     _tp_plan = None
 
+    _output_embedding = None
+    _input_embedding = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
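For orientation, a minimal standalone sketch of the pattern this hunk introduces (hypothetical toy classes, not the actual transformers code): the base class declares `_input_embedding` / `_output_embedding` as None sentinels, and concrete models override them with the attribute name of the corresponding module.

from torch import nn

class BaseSketch(nn.Module):
    # None sentinels on the base class, as in the hunk above;
    # subclasses override them with an attribute name.
    _input_embedding = None
    _output_embedding = None

class ToySketch(BaseSketch):
    _input_embedding = "embed_tokens"
    _output_embedding = "lm_head"

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)

model = ToySketch()
print(getattr(model, model._input_embedding))  # Embedding(100, 16)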
@@ -1832,7 +1835,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
-            return getattr(self, self._embeding_layer)
+            return getattr(self, self._input_embedding)
 
     def set_input_embeddings(self, value: nn.Module):
         """
@@ -1845,7 +1848,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if base_model is not self:
             base_model.set_input_embeddings(value)
         else:
-            raise setattr(self, self._embeding_layer, value)
+            setattr(self, self._input_embedding, value)
 
     def get_output_embeddings(self) -> nn.Module:
         """
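Taken together, the two hunks above amount to the following getter/setter pair. This is a simplified sketch (the `base_model_prefix` resolution is an assumption about how PreTrainedModel locates its wrapped base model), not a copy of the library code:

from torch import nn

class PreTrainedSketch(nn.Module):
    # Hypothetical stand-in for PreTrainedModel.
    base_model_prefix = ""
    _input_embedding = None

    def get_input_embeddings(self) -> nn.Module:
        # Delegate to the wrapped base model when one exists.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            return getattr(self, self._input_embedding)

    def set_input_embeddings(self, value: nn.Module):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            # setattr returns None, so the removed line's `raise setattr(...)`
            # would itself fail with a TypeError; a bare setattr is enough.
            setattr(self, self._input_embedding, value)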
@@ -19,7 +19,7 @@ from ..auto import AutoModel
 class AutoForCausalLM(PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
-    _embeding_layer = "model.embed_tokens"
+    _input_embedding = "model.embed_tokens"
     _output_embedding = "lm_head"
     _no_split_modules = []
     _supports_cache_class = True
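One caveat worth flagging: here `_input_embedding` is a dotted path ("model.embed_tokens"), which plain `getattr` cannot resolve. In practice the `base_model is not self` branch should delegate to the inner model before that lookup happens; if a dotted path ever had to be resolved directly, a chained lookup along the lines of this hypothetical helper would be needed:

from functools import reduce

def resolve_dotted(obj, path: str):
    # Walk a dotted attribute path such as "model.embed_tokens";
    # getattr(obj, "model.embed_tokens") would raise AttributeError.
    return reduce(getattr, path.split("."), obj)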
@@ -347,7 +347,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
 
 
 class LlamaModel(LlamaPreTrainedModel):
-    _embedding_layer = "embed_tokens"
+    _input_embedding = "embed_tokens"
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)