From 014047e1c8784c00e2a04cb04ffcecdd5cb23c16 Mon Sep 17 00:00:00 2001 From: DeepWave <31004098+DeepWaved@users.noreply.github.com> Date: Fri, 7 Feb 2025 17:43:45 +0800 Subject: [PATCH] Fix bug in apply_rotary_pos_emb_flashatt: in Qwen2-5-VL (#36065) --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 4 ++-- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 82b112ad3..20b61ddf8 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -162,8 +162,8 @@ class Qwen2_5_VLPatchMerger(nn.Module): def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: tensor_ = tensor.float() - cos = freqs.cos() - sin = freqs.sin() + cos = freqs.cos().float() + sin = freqs.sin().float() output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) return output diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 601ad3737..7646bb6e3 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -65,8 +65,8 @@ else: def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: tensor_ = tensor.float() - cos = freqs.cos() - sin = freqs.sin() + cos = freqs.cos().float() + sin = freqs.sin().float() output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) return output