diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 78a11176e..78186b062 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1776,7 +1776,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
         if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+            if (
+                (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            ):
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index 1e155d604..15abcb53d 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -675,7 +675,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
         # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
         if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+            if (
+                (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            ):
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index dd0f80cc3..512cc602c 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1648,7 +1648,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
         if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+            if (
+                (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            ):
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids, image_grid_thw, video_grid_thw, attention_mask
                 )
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index 05cf22a3f..6ef958e2a 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -284,6 +284,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

+    def test_forward_with_rope_deltas_cached(self):
+        """
+        Tests that Qwen2-VL computes new rope deltas on every forward pass with a new set of inputs.
+        Rope deltas are cached when we generate and re-used in the decoding phase, but are not reset
+        automatically after generation ends. See https://github.com/huggingface/transformers/pull/36013 for more.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config).to(torch_device)
+
+            # Generate and make sure rope_deltas are not `None`
+            self.assertTrue(model.rope_deltas is None)
+            generation_output = model.generate(
+                **input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
+            )
+            self.assertTrue(model.rope_deltas is not None)
+
+            # A plain forward pass without a cache must recompute rope deltas and match the first generation step's logits
+            forward_output = model(**input_dict)
+            torch.testing.assert_close(
+                generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
+            )
+
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass
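
Reviewer aid: the following is a minimal, self-contained sketch of the patched pre-fill check in isolation, useful for seeing why the third clause fixes the stale-deltas bug. Only `DynamicCache` and its `get_seq_length()` come from `transformers`; the helper name and the toy values are illustrative, not library code.

# Sketch of the patched condition, assuming `past_key_values` follows the
# transformers Cache API (e.g. `DynamicCache`). Illustrative only.
import torch
from transformers.cache_utils import DynamicCache


def should_recompute_rope_deltas(cache_position, rope_deltas, past_key_values):
    # Mirrors the patched `if` above
    return (
        (cache_position is not None and cache_position[0] == 0)  # pre-fill step inside generate()
        or rope_deltas is None  # nothing cached yet
        or (past_key_values is None or past_key_values.get_seq_length() == 0)  # no live KV cache
    )


stale_deltas = torch.tensor([3])  # stands in for deltas left over from a previous generate()

# Plain forward pass after generation: no cache_position and no KV cache are passed,
# so the new third clause forces a recomputation (before this patch this returned False,
# silently re-using the stale deltas).
assert should_recompute_rope_deltas(None, stale_deltas, None)

# Decoding step inside generate(): a non-empty cache plus cached deltas means the
# deltas are re-used, exactly as before the patch.
cache = DynamicCache()
cache.update(torch.zeros(1, 1, 5, 4), torch.zeros(1, 1, 5, 4), layer_idx=0)  # 5 cached positions
assert not should_recompute_rope_deltas(torch.tensor([5]), stale_deltas, cache)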