Qwen2-VL: fix rope delta calculation (#36013)

* fix rope delats calculation

* add test

* style
This commit is contained in:
Raushan Turganbay 2025-02-04 09:48:29 +01:00 committed by GitHub
parent e284c7e954
commit 5d75a25b03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 38 additions and 3 deletions

View file

@ -1776,7 +1776,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,

View file

@ -675,7 +675,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,

View file

@ -1648,7 +1648,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)

View file

@ -284,6 +284,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
def test_forward_with_rope_deltas_cached(self):
"""
Tests that Qwen2-VL computes new rope deltas every forward pass with new set of inputs.
Rope deltas are cached when we generate and re-used for decoding phase, byt are not reset
automatically after generation ends. See https://github.com/huggingface/transformers/pull/36013 for more
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
# Generate and make sure rope_deltas are not `None`
self.assertTrue(model.rope_deltas is None)
generation_output = model.generate(
**input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
)
self.assertTrue(model.rope_deltas is not None)
# Now if we try to do forward pass, we should get new rope logits, because cache is not passed
forward_output = model(**input_dict)
torch.testing.assert_close(
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
)
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass