diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index e2c5781f7..2cb9c7cad 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -1217,7 +1217,7 @@ class TFGenerationMixin: # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of # the attention mask) can rely on the actual model input. model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0] + inputs, bos_token_id, model_kwargs=model_kwargs ) else: if inputs is not None: @@ -1225,9 +1225,7 @@ class TFGenerationMixin: inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs.get("encoder_outputs") - ) + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) return inputs, input_name, model_kwargs @@ -1235,13 +1233,13 @@ class TFGenerationMixin: self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None, - encoder_outputs: Optional[ModelOutput] = None, - batch_size: Optional[int] = None, + model_kwargs: Optional[Dict[str, tf.Tensor]] = None, ) -> tf.Tensor: """Initializes input ids for generation, if necessary.""" if inputs is not None: return inputs + encoder_outputs = model_kwargs.get("encoder_outputs") if self.config.is_encoder_decoder and encoder_outputs is not None: # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding shape = encoder_outputs.last_hidden_state.shape[:-1] @@ -1250,7 +1248,13 @@ class TFGenerationMixin: if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - batch_size = batch_size if batch_size is not None else 1 + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, tf.Tensor): + batch_size = value.shape[0] + break return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id @staticmethod diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 61f6090a9..062455049 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -544,7 +544,7 @@ class GenerationMixin: # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of # the attention mask) can rely on the actual model input. model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0] + inputs, bos_token_id, model_kwargs=model_kwargs ) else: if inputs is not None: @@ -552,9 +552,7 @@ class GenerationMixin: inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs.get("encoder_outputs") - ) + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) return inputs, input_name, model_kwargs def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: @@ -567,13 +565,13 @@ class GenerationMixin: self, inputs: Optional[torch.Tensor] = None, bos_token_id: Optional[int] = None, - encoder_outputs: Optional[ModelOutput] = None, - batch_size: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" if inputs is not None: return inputs + encoder_outputs = model_kwargs.get("encoder_outputs") if self.config.is_encoder_decoder and encoder_outputs is not None: # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding shape = encoder_outputs.last_hidden_state.size()[:-1] @@ -582,7 +580,13 @@ class GenerationMixin: if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - batch_size = batch_size if batch_size is not None else 1 + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index b3b34f664..969a0b174 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -340,6 +340,24 @@ class GitModelTester: self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20)) + def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values): + model = GitForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # generate + generated_ids = model.generate( + input_ids=None, # captioning -> no input_ids + attention_mask=None, + pixel_values=pixel_values, + do_sample=False, + max_length=20, + num_beams=2, + num_return_sequences=2, + ) + + self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester._test_beam_search_generate(*config_and_inputs) + def test_batched_generate_captioning(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester._test_batched_generate_captioning(*config_and_inputs) + def test_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() for type in ["absolute", "relative_key", "relative_key_query"]: