From e7c3fa7f57ea5df2eedc6c7766ade06d75060904 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:57:44 -0400 Subject: [PATCH] Fix continue_final_message for image-text-to-text chat templates (#34236) * fix continue_final_message for vlms * Add one test for vlms continue_final_message chat template --- src/transformers/tokenization_utils_base.py | 5 ++++- tests/models/llava/test_processor_llava.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b52a93ae9..16c05a140 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1874,7 +1874,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): **template_kwargs, ) if continue_final_message: - final_message = chat[-1]["content"].strip() + final_message = chat[-1]["content"] + if isinstance(final_message, (list, tuple)): + final_message = final_message[-1]["text"] + final_message = final_message.strip() rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip() rendered.append(rendered_chat) diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index 06a180615..d3a66a16d 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -93,3 +93,24 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) + + def test_chat_template_with_continue_final_message(self): + processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and" + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "There is a dog and"}, + ], + }, + ] + prompt = processor.apply_chat_template(messages, continue_final_message=True) + self.assertEqual(expected_prompt, prompt)