From dc42330388b2243d8fed3c0fde47db8d5c6b8e1d Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 19 Nov 2024 23:51:32 +0800
Subject: [PATCH] fix crash in tiiuae/falcon-11B-vlm image-to-text generation
 (#34728)

Signed-off-by: Wang, Yi
---
 src/transformers/models/falcon/modeling_falcon.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 504dcf10b..faea670ec 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -1277,12 +1277,18 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
         """
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1302,7 +1308,7 @@ class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
         )
         hidden_states = transformer_outputs[0]
 
-        lm_logits = self.lm_head(hidden_states)
+        lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
         loss = None
         if labels is not None:
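
Note (not part of the patch): a minimal sketch of why the `hidden_states[:, -num_logits_to_keep:, :]` slice behaves as the new docstring describes. The tensor shapes and the standalone `lm_head` below are illustrative assumptions, not the actual Falcon dimensions.

import torch

# Illustrative shapes (assumptions, not Falcon's real dimensions).
batch, seq_len, hidden_size, vocab_size = 2, 1024, 64, 32000
hidden_states = torch.randn(batch, seq_len, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# num_logits_to_keep == 0 is the special case: hidden_states[:, -0:, :] is the
# same as hidden_states[:, 0:, :], so logits are computed for every position.
logits_all = lm_head(hidden_states[:, -0:, :])
print(logits_all.shape)   # torch.Size([2, 1024, 32000])

# num_logits_to_keep == 1, as during generation: only the last position goes
# through lm_head, avoiding a full seq_len x vocab_size logits matrix.
logits_last = lm_head(hidden_states[:, -1:, :])
print(logits_last.shape)  # torch.Size([2, 1, 32000])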