diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 82b112ad3..20b61ddf8 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -162,8 +162,8 @@ class Qwen2_5_VLPatchMerger(nn.Module): def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: tensor_ = tensor.float() - cos = freqs.cos() - sin = freqs.sin() + cos = freqs.cos().float() + sin = freqs.sin().float() output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) return output diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 601ad3737..7646bb6e3 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -65,8 +65,8 @@ else: def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: tensor_ = tensor.float() - cos = freqs.cos() - sin = freqs.sin() + cos = freqs.cos().float() + sin = freqs.sin().float() output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) return output