From 57e9c8321385dfd31bda33df144a4ac849206e06 Mon Sep 17 00:00:00 2001
From: Fernando Rodriguez Sanchez
Date: Fri, 5 Jan 2024 12:36:10 +0100
Subject: [PATCH] Fix pos_mask application and update tests accordingly (#27892)

* Fix pos_mask application and update tests accordingly

* Fix style

* Adding comments

---------

Co-authored-by: Fernando Rodriguez
---
 .../models/flava/modeling_flava.py        |  5 +-
 tests/models/flava/test_modeling_flava.py | 55 +++++++++++++++++++
 2 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py
index 64ede9c89..f96e4292a 100644
--- a/src/transformers/models/flava/modeling_flava.py
+++ b/src/transformers/models/flava/modeling_flava.py
@@ -1949,6 +1949,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
 
         if mim_labels is not None:
             mim_labels = mim_labels[pos_mask]
+            bool_masked_pos = bool_masked_pos[pos_mask]
 
         # MMM Image Loss
         if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
@@ -1956,8 +1957,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
             end_index = image_masked_embeddings.size(1) - 1
             sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
 
-            if pos_mask is not None:
-                sequence_for_image = sequence_for_image[pos_mask]
             if mim_labels is not None:
                 mim_labels = self._resize_to_2d(mim_labels)
                 bool_masked_pos = self._resize_to_2d(bool_masked_pos)
@@ -1979,8 +1978,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
         if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
             sequence_for_text = multimodal_masked_embeddings
             sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
-            if pos_mask is not None:
-                sequence_for_text = sequence_for_text[pos_mask]
 
             if mlm_labels is not None:
                 mlm_labels = self._resize_to_2d(mlm_labels)

diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py
index e4b3990dc..48a070d9f 100644
--- a/tests/models/flava/test_modeling_flava.py
+++ b/tests/models/flava/test_modeling_flava.py
@@ -1313,8 +1313,12 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
             return_codebook_pixels=True,
             return_image_mask=True,
         )
+        # Create a clone of the input_ids tensor that will be its masked version
         inputs["input_ids_masked"] = inputs["input_ids"].clone()
+        # Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special value 103
         inputs["input_ids_masked"][0, 4:6] = 103
+        # MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
+        # except those that are masked, which keep their original values
         inputs["mlm_labels"] = inputs["input_ids"].clone()
         inputs["mlm_labels"][:, :] = -100
         inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
@@ -1338,3 +1342,54 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
         self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
         self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
+
+    @slow
+    def test_inference_with_itm_labels(self):
+        model_name = "facebook/flava-full"
+        model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
+        processor = FlavaProcessor.from_pretrained(model_name)
+        torch.manual_seed(1)
+        random.seed(1)
+
+        image = prepare_img()
+        inputs = processor(
+            text=["a photo of a cat", "a photo of a dog"],
+            images=[image, image],
+            padding="max_length",
+            max_length=77,
+            return_tensors="pt",
+            return_codebook_pixels=True,
+            return_image_mask=True,
+        )
+        # Create a clone of the input_ids tensor that will be its masked version
+        inputs["input_ids_masked"] = inputs["input_ids"].clone()
+        # Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special value 103
+        inputs["input_ids_masked"][0, 4:6] = 103
+        # MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
+        # except those that are masked, which keep their original values
+        inputs["mlm_labels"] = inputs["input_ids"].clone()
+        inputs["mlm_labels"][:, :] = -100
+        inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
+        # Manually create the itm_labels tensor that indicates whether each image-text pair matches.
+        # In this case, the first pair matches and the second does not
+        inputs["itm_labels"] = torch.tensor([1, 0])
+        inputs = inputs.to(torch_device)
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # verify the logits
+        self.assertEqual(
+            outputs.contrastive_logits_per_image.shape,
+            torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.input_ids.shape[0])),
+        )
+        self.assertEqual(
+            outputs.contrastive_logits_per_text.shape,
+            torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.pixel_values.shape[0])),
+        )
+
+        expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
+        self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
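
The diff moves pos_mask handling so that mim_labels and bool_masked_pos are
filtered together, and drops the later pos_mask filtering of the multimodal
sequences inside the MMM loss blocks. The invariant it restores: every
per-sample tensor that meets in a boolean-indexing step of the
masked-modeling losses must have been filtered by the same pos_mask exactly
once; filtering one tensor but not the others, or filtering one twice,
misaligns rows across samples or fails with a shape mismatch. The sketch
below is a minimal, self-contained illustration of that invariant, not the
FLAVA implementation; the tensor names and shapes are assumptions chosen for
the example.

import torch

batch_size, seq_len, hidden = 4, 6, 8

# Hypothetical stand-ins for the tensors that feed the MIM/MMM losses.
multimodal_embeddings = torch.randn(batch_size, seq_len, hidden)
mim_labels = torch.randint(0, 100, (batch_size, seq_len))
bool_masked_pos = torch.rand(batch_size, seq_len) > 0.5  # masked positions

# pos_mask keeps only the matching image-text pairs (itm_labels == 1).
itm_labels = torch.tensor([1, 0, 1, 1])
pos_mask = itm_labels == 1

# Filter every per-sample tensor with the same pos_mask, exactly once,
# which is what the patch enforces for mim_labels and bool_masked_pos.
multimodal_embeddings = multimodal_embeddings[pos_mask]
mim_labels = mim_labels[pos_mask]
bool_masked_pos = bool_masked_pos[pos_mask]

# The boolean indexing below only lines up because all three tensors were
# filtered consistently; skipping bool_masked_pos, or filtering the
# embeddings a second time, would break this alignment.
masked_states = multimodal_embeddings[bool_masked_pos]  # (num_masked, hidden)
masked_labels = mim_labels[bool_masked_pos]             # (num_masked,)
assert masked_states.shape[0] == masked_labels.shape[0]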