Fix bug in apply_rotary_pos_emb_flashatt: in Qwen2-5-VL (#36065)

This commit is contained in:
DeepWave 2025-02-07 17:43:45 +08:00 committed by GitHub
parent 006d9249ec
commit 014047e1c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View file

@@ -162,8 +162,8 @@ class Qwen2_5_VLPatchMerger(nn.Module):
 def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
     tensor_ = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
+    cos = freqs.cos().float()
+    sin = freqs.sin().float()
     output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
     return output

View file

@@ -65,8 +65,8 @@ else:
 def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
     tensor_ = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
+    cos = freqs.cos().float()
+    sin = freqs.sin().float()
     output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
     return output