mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Qwen2-VL: fix rope delta calculation (#36013)
* fix rope deltas calculation * add test * style
This commit is contained in:
parent
e284c7e954
commit
5d75a25b03
4 changed files with 38 additions and 3 deletions
|
|
@ -1776,7 +1776,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
|||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
|
|
|
|||
|
|
@ -675,7 +675,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
|
|
|
|||
|
|
@ -1648,7 +1648,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
|
|
|
|||
|
|
@ -284,6 +284,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
|
||||
|
||||
def test_forward_with_rope_deltas_cached(self):
    """
    Ensure Qwen2-VL computes fresh rope deltas for every forward pass with new inputs.

    Rope deltas are cached during `generate` and reused for the decoding phase, but they
    are not reset automatically after generation ends. See
    https://github.com/huggingface/transformers/pull/36013 for more details.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_generative_model_classes:
        model = model_class(config).to(torch_device)

        # Before any generation the delta cache must be empty; after a
        # `generate` call it must be populated.
        self.assertIsNone(model.rope_deltas)
        gen_out = model.generate(
            **input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
        )
        self.assertIsNotNone(model.rope_deltas)

        # A plain forward pass (no cache passed) must recompute rope deltas,
        # so its last-position logits match the first generated-token logits.
        fwd_out = model(**input_dict)
        torch.testing.assert_close(
            gen_out.logits[0], fwd_out.logits[:, -1, :], rtol=1e-4, atol=1e-4
        )
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
    # Inherited mixin test disabled: this model does not implement
    # feed-forward chunking, so the generic check cannot run.
    pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue