From 594c1610fa6243b2ffb670c49faf389ca5121939 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 8 Jul 2024 15:48:32 +0100
Subject: [PATCH] Mamba & RecurrentGemma: enable strict signature (#31549)

* enable strict signature

* this should not have been deleted

* recurrent_gemma too
---
 src/transformers/generation/utils.py          | 63 ++++++++-----------
 .../models/mamba/modeling_mamba.py            |  2 -
 .../modeling_recurrent_gemma.py               |  2 -
 3 files changed, 27 insertions(+), 40 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f99ae64fb..371c59460 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2692,13 +2692,12 @@ class GenerationMixin:
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2919,6 +2918,10 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # if sequential is True, split the input to batches of batch_size and run sequentially
             if sequential:
                 if any(
@@ -2944,24 +2947,13 @@ class GenerationMixin:
                     model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
                 )
                 outputs_per_sub_batch = [
-                    self(
-                        **inputs_per_sub_batch,
-                        return_dict=True,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                    )
-                    for inputs_per_sub_batch in inputs_per_sub_batches
+                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]

                 outputs = stack_model_outputs(outputs_per_sub_batch)

             else:  # Unchanged original behavior
-                outputs = self(
-                    **model_inputs,
-                    return_dict=True,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                )
+                outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3241,12 +3233,12 @@ class GenerationMixin:

             # do one decoder step on all beams of all sentences in batch
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3522,12 +3514,11 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3793,11 +3784,11 @@ class GenerationMixin:
                 model_inputs["num_logits_to_keep"] = candidate_length + 1

             # 2.2. Run a forward pass on the candidate sequence
-            outputs = self(
-                **model_inputs,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs)

             # 2.3. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 04430ada8..aa1bec59f 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -545,7 +545,6 @@ class MambaModel(MambaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -673,7 +672,6 @@ class MambaForCausalLM(MambaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 2a8e1c25f..40032851b 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -684,7 +684,6 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -823,7 +822,6 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, CausalLMOutput]:
         r"""
         Args:
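
For reference, the pattern this patch standardizes can be sketched outside of `transformers`. This is a minimal, self-contained sketch: the helper names `strict_forward` and `run_forward` are invented for illustration and are not part of the library; the two `model_inputs.update(...)` lines are the ones the patch adds to the generation loops. Because the output-control flags are injected only when truthy, a model whose forward has a strict signature (no `**kwargs` catch-all, which is what the Mamba and RecurrentGemma changes enable) never receives a flag it does not support unless the caller explicitly asks for that output:

def strict_forward(input_ids, use_cache=None, output_hidden_states=None, return_dict=None):
    # A strict signature in the spirit of MambaModel.forward after this patch:
    # no **kwargs, and no `output_attentions` (Mamba has no attention layers).
    return {"input_ids": input_ids, "return_dict": return_dict}

def run_forward(forward, model_inputs, output_attentions=False, output_hidden_states=False):
    # prepare variable output controls (note: some models won't accept all output controls)
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    return forward(**model_inputs, return_dict=True)

# Default flags: nothing is injected, so the strict signature is satisfied.
run_forward(strict_forward, {"input_ids": [1, 2, 3]})

# Explicitly requesting attentions now fails loudly instead of the flag being
# silently swallowed by a **kwargs catch-all.
try:
    run_forward(strict_forward, {"input_ids": [1, 2, 3]}, output_attentions=True)
except TypeError as exc:
    print(exc)  # unexpected keyword argument 'output_attentions'

The trade-off of the stricter signatures is visible in the last call: asking an attention-free model for attentions raises a `TypeError` rather than being ignored, which surfaces unsupported options to the user instead of hiding them.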