From efdbad56ab5e90c223468e862e26f89e422b4782 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 25 Jan 2023 10:14:18 +0100 Subject: [PATCH] [GIT] Add test for batched generation (#21282) * Add test * Apply suggestions Co-authored-by: Niels Rogge --- tests/models/git/test_modeling_git.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 67bede12b..e399ddea5 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -495,3 +495,22 @@ class GitModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 15)) self.assertEqual(generated_ids.shape, expected_shape) self.assertEquals(generated_caption, "what does the front of the bus say at the top? special") + + def test_batched_generation(self): + processor = GitProcessor.from_pretrained("microsoft/git-base-coco") + model = GitForCausalLM.from_pretrained("microsoft/git-base-coco") + model.to(torch_device) + + # create batch of size 2 + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(images=[image, image], return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # we have to prepare `input_ids` with the same batch size as `pixel_values` + start_token_id = model.config.bos_token_id + generated_ids = model.generate( + pixel_values=pixel_values, input_ids=torch.tensor([[start_token_id], [start_token_id]]).to(torch_device), max_length=50 + ) + generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True) + + self.assertEqual(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)