mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Generate: Fix GIT batched captioning (#21738)
This commit is contained in:
parent
78a93d17c0
commit
1d4b797852
3 changed files with 44 additions and 14 deletions
|
|
@ -1217,7 +1217,7 @@ class TFGenerationMixin:
|
|||
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
||||
# the attention mask) can rely on the actual model input.
|
||||
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
||||
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
|
||||
inputs, bos_token_id, model_kwargs=model_kwargs
|
||||
)
|
||||
else:
|
||||
if inputs is not None:
|
||||
|
|
@ -1225,9 +1225,7 @@ class TFGenerationMixin:
|
|||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||
|
||||
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||
inputs = self._maybe_initialize_input_ids_for_generation(
|
||||
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
|
||||
)
|
||||
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
|
||||
|
||||
return inputs, input_name, model_kwargs
|
||||
|
||||
|
|
@ -1235,13 +1233,13 @@ class TFGenerationMixin:
|
|||
self,
|
||||
inputs: Optional[tf.Tensor] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
encoder_outputs: Optional[ModelOutput] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
|
||||
) -> tf.Tensor:
|
||||
"""Initializes input ids for generation, if necessary."""
|
||||
if inputs is not None:
|
||||
return inputs
|
||||
|
||||
encoder_outputs = model_kwargs.get("encoder_outputs")
|
||||
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
||||
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
||||
shape = encoder_outputs.last_hidden_state.shape[:-1]
|
||||
|
|
@ -1250,7 +1248,13 @@ class TFGenerationMixin:
|
|||
if bos_token_id is None:
|
||||
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
||||
|
||||
batch_size = batch_size if batch_size is not None else 1
|
||||
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
|
||||
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
|
||||
batch_size = 1
|
||||
for value in model_kwargs.values():
|
||||
if isinstance(value, tf.Tensor):
|
||||
batch_size = value.shape[0]
|
||||
break
|
||||
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -544,7 +544,7 @@ class GenerationMixin:
|
|||
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
|
||||
# the attention mask) can rely on the actual model input.
|
||||
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
||||
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
|
||||
inputs, bos_token_id, model_kwargs=model_kwargs
|
||||
)
|
||||
else:
|
||||
if inputs is not None:
|
||||
|
|
@ -552,9 +552,7 @@ class GenerationMixin:
|
|||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||
|
||||
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||
inputs = self._maybe_initialize_input_ids_for_generation(
|
||||
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
|
||||
)
|
||||
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
|
||||
return inputs, input_name, model_kwargs
|
||||
|
||||
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
|
|
@ -567,13 +565,13 @@ class GenerationMixin:
|
|||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
encoder_outputs: Optional[ModelOutput] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""Initializes input ids for generation, if necessary."""
|
||||
if inputs is not None:
|
||||
return inputs
|
||||
|
||||
encoder_outputs = model_kwargs.get("encoder_outputs")
|
||||
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
||||
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
||||
shape = encoder_outputs.last_hidden_state.size()[:-1]
|
||||
|
|
@ -582,7 +580,13 @@ class GenerationMixin:
|
|||
if bos_token_id is None:
|
||||
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
||||
|
||||
batch_size = batch_size if batch_size is not None else 1
|
||||
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
|
||||
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
|
||||
batch_size = 1
|
||||
for value in model_kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
batch_size = value.shape[0]
|
||||
break
|
||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||
|
||||
def _prepare_attention_mask_for_generation(
|
||||
|
|
|
|||
|
|
@ -340,6 +340,24 @@ class GitModelTester:
|
|||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
|
||||
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# generate
|
||||
generated_ids = model.generate(
|
||||
input_ids=None, # captioning -> no input_ids
|
||||
attention_mask=None,
|
||||
pixel_values=pixel_values,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
|
|
@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
||||
|
||||
def test_batched_generate_captioning(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester._test_batched_generate_captioning(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
|
|
|
|||
Loading…
Reference in a new issue