From 4349a0e4012330b83116a3dc671292a8e46df923 Mon Sep 17 00:00:00 2001 From: Minho Shim <6764739+minostauros@users.noreply.github.com> Date: Thu, 9 Jan 2025 00:36:03 +0900 Subject: [PATCH] fix: Qwen2-VL generate with inputs_embeds (#35466) * fix: Qwen2-VL generate with inputs_embeds * change: optional input_ids in get_rope_index --- .../models/qwen2_vl/modeling_qwen2_vl.py | 15 +++++---------- tests/generation/test_utils.py | 4 +--- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 0f04b1d5e..ea169987b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -32,13 +32,8 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) -from ...modeling_outputs import ( - BaseModelOutputWithPast, - ModelOutput, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1420,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): def get_rope_index( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -1550,7 +1545,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: @@ -1676,7 +1671,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4ac22e777..c19e0cc4f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1615,9 +1615,7 @@ class GenerationTesterMixin: # There are a few exception patterns in this test: # 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed - requires_inputs_ids = any( - model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"] - ) + requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"]) # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex # than calling the embedding layer with `input_ids`. Subcases of this exception: # 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)