diff --git a/docs/source/en/model_doc/pix2struct.mdx b/docs/source/en/model_doc/pix2struct.mdx index 340b06c69..fb4ecf05e 100644 --- a/docs/source/en/model_doc/pix2struct.mdx +++ b/docs/source/en/model_doc/pix2struct.mdx @@ -28,9 +28,8 @@ We therefore advise you to use these models for the tasks they have been fine tu This model was contributed by [ybelkada](https://huggingface.co/ybelkada). The original code can be found [here](https://github.com/google-research/pix2struct). -## Resources: +## Resources -- [Paper](https://arxiv.org/abs/2210.03347) - [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb) - [All models](https://huggingface.co/models?search=pix2struct) @@ -70,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2str ## Pix2StructForConditionalGeneration [[autodoc]] Pix2StructForConditionalGeneration - - forward + - forward \ No newline at end of file diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index d29e1dbce..1df7b57c7 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -681,7 +681,7 @@ class GenerationConfig(PushToHubMixin): # Special case: some models have generation attributes set in the decoder. Use them if still unset in the # generation config. 
- for decoder_name in ("decoder", "generator"): + for decoder_name in ("decoder", "generator", "text_config"): if decoder_name in config_dict: default_generation_config = GenerationConfig() decoder_config = config_dict[decoder_name] diff --git a/src/transformers/models/pix2struct/configuration_pix2struct.py b/src/transformers/models/pix2struct/configuration_pix2struct.py index dead3d8a0..244cb2705 100644 --- a/src/transformers/models/pix2struct/configuration_pix2struct.py +++ b/src/transformers/models/pix2struct/configuration_pix2struct.py @@ -358,9 +358,10 @@ class Pix2StructConfig(PretrainedConfig): initializer_range=0.02, is_vqa=False, tie_word_embeddings=False, + is_encoder_decoder=True, **kwargs, ): - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs) if text_config is None: text_config = {} @@ -373,9 +374,9 @@ class Pix2StructConfig(PretrainedConfig): self.text_config = Pix2StructTextConfig(**text_config) self.vision_config = Pix2StructVisionConfig(**vision_config) - self.text_config.encoder_hidden_size = self.vision_config.hidden_size self.decoder_start_token_id = self.text_config.decoder_start_token_id self.pad_token_id = self.text_config.pad_token_id + self.eos_token_id = self.text_config.eos_token_id self.initializer_factor = initializer_factor self.initializer_range = initializer_range diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index ead913e1d..6ce279f02 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -14,7 +14,6 @@ # limitations under the License. 
""" Pix2Struct modeling file""" -import copy import math from typing import Dict, List, Optional, Tuple, Union @@ -1580,25 +1579,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): cross_attentions=all_cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } - @add_start_docstrings( "A conditional generation model with a language modeling head. 
Can be used for sequence generation tasks.", @@ -1618,13 +1598,9 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): def __init__(self, config: Pix2StructConfig): super().__init__(config) - encoder_config = copy.deepcopy(config.vision_config) - self.encoder = Pix2StructVisionModel(encoder_config) - decoder_config = copy.deepcopy(config.text_config) - self.decoder_start_token_id = decoder_config.pad_token_id - self.decoder_eos_token_ids = decoder_config.eos_token_id - self.decoder = Pix2StructTextModel(decoder_config) + self.encoder = Pix2StructVisionModel(config.vision_config) + self.decoder = Pix2StructTextModel(config.text_config) self.is_vqa = config.is_vqa @@ -1682,6 +1658,8 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): Example: + Inference: + ```python >>> from PIL import Image >>> import requests @@ -1690,15 +1668,40 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") - >>> labels = "A stop sign is on the street corner." >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, text=labels, return_tensors="pt", add_special_tokens=True) + >>> inputs = processor(images=image, return_tensors="pt") + + >>> # autoregressive generation + >>> generated_ids = model.generate(**inputs, max_new_tokens=50) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_text) + A stop sign is on a street corner. 
+ ``` + + Training: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base") + >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A stop sign is on the street corner." + + >>> inputs = processor(images=image, return_tensors="pt") + >>> labels = processor(text=text, return_tensors="pt").input_ids >>> # forward pass - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.loss + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> print(loss.item()) + 5.239729881286621 ```""" use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1759,54 +1762,29 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): encoder_attentions=encoder_outputs.attentions, ) - @torch.no_grad() - def generate( + def prepare_inputs_for_generation( self, - flattened_patches: torch.FloatTensor, - decoder_input_ids: Optional[torch.LongTensor] = None, + input_ids, + flattened_patches: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - **generate_kwargs, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, ): - r""" - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration - - >>> processor = 
AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") - >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") - - >>> conditional_text = "A stop sign" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=conditional_text, return_tensors="pt", add_special_tokens=True) - - >>> # forward pass - >>> outputs = model.generate(**inputs) - >>> print(processor.batch_decode(outputs, skip_special_tokens=True)) - ['A stop sign the street with a sign that says yes'] - ```""" - batch_size, _, _ = flattened_patches.shape - - vision_outputs = self.encoder(flattened_patches=flattened_patches, attention_mask=attention_mask) - - image_embeds = vision_outputs[0] - - if isinstance(decoder_input_ids, torch.Tensor): - # check if the first element of `input_ids` is equal to `decoder_input_ids`: - if (decoder_input_ids[:, 0] != self.decoder_start_token_id).all().item(): - # add `decoder_input_ids` as first token to `input_ids` - decoder_input_ids = torch.cat( + if isinstance(input_ids, torch.Tensor): + # check whether the first token of `input_ids` is already `decoder_start_token_id`: + if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item(): + # prepend `decoder_start_token_id` as first token to `input_ids` + input_ids = torch.cat( [ - torch.ones((decoder_input_ids.shape[0], 1), dtype=torch.long, device=decoder_input_ids.device) - * self.decoder_start_token_id, - decoder_input_ids, + torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) + * self.config.decoder_start_token_id, + input_ids, ], dim=-1, ) @@ -1823,20 +1801,26 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): ], dim=-1, ) - elif decoder_input_ids is None: - decoder_input_ids = ( - torch.LongTensor([[self.decoder_start_token_id]]).repeat(batch_size, 1).to(image_embeds.device) - ) + elif input_ids is None: + batch_size = 
flattened_patches.shape[0] + input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device) if decoder_attention_mask is None: - decoder_attention_mask = torch.ones_like(decoder_input_ids).to(image_embeds.device) + decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) - outputs = self.decoder.generate( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=attention_mask, - **generate_kwargs, - ) + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] - return outputs + return { + "flattened_patches": flattened_patches, + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index f56f8f6d3..42ee3c2b4 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): # signature.parameters is an OrderedDict => so arg_names order is deterministic arg_names = [*signature.parameters.keys()] - if model.config.is_encoder_decoder: - expected_arg_names = [ - "input_ids", - "attention_mask", - "decoder_input_ids", - "decoder_attention_mask", - ] - expected_arg_names.extend( - ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] - if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names - else ["encoder_outputs"] - ) - self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) - else: 
- expected_arg_names = ( - ["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"] - ) - self.assertListEqual(arg_names[:1], expected_arg_names) + expected_arg_names = [ + "flattened_patches", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + "head_mask", + "decoder_head_mask", + "cross_attn_head_mask", + "encoder_outputs", + "past_key_values", + "labels", + "decoder_inputs_embeds", + "use_cache", + ] + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) def test_training(self): if not self.model_tester.is_training: @@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase): ) def test_vqa_model(self): - model_id = "ybelkada/pix2struct-ai2d-base" + model_id = "google/pix2struct-ai2d-base" image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" image = Image.open(requests.get(image_url, stream=True).raw) @@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase): self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud") def test_vqa_model_batched(self): - model_id = "ybelkada/pix2struct-ai2d-base" + model_id = "google/pix2struct-ai2d-base" image_urls = [ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index f90e6055c..d3143b216 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -306,6 +306,7 @@ src/transformers/models/pegasus/tokenization_pegasus.py src/transformers/models/pegasus/tokenization_pegasus_fast.py src/transformers/models/perceiver/tokenization_perceiver.py src/transformers/models/phobert/tokenization_phobert.py +src/transformers/models/pix2struct/modeling_pix2struct.py src/transformers/models/plbart/tokenization_plbart.py 
src/transformers/models/prophetnet/tokenization_prophetnet.py src/transformers/models/rag/tokenization_rag.py