mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix bug in apply_rotary_pos_emb_flashatt in Qwen2.5-VL (#36065)
This commit is contained in:
parent
128b840247
commit
8aa45e177e
2 changed files with 4 additions and 4 deletions
|
|
@@ -162,8 +162,8 @@ class Qwen2_5_VLPatchMerger(nn.Module):
|
|||
|
||||
def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding via the flash-attention kernel.

    Args:
        tensor: Input states (query or key); may be in reduced precision.
        freqs: Rotary angles from which the cos/sin tables are computed.

    Returns:
        ``tensor`` with rotary embedding applied, cast back to its original dtype.
    """
    tensor_ = tensor.float()
    # Bug fix (#36065): cast cos/sin explicitly to float32 so they match
    # ``tensor_``'s dtype before entering the flash-attn kernel. The stale
    # non-``.float()`` assignments from before the fix are removed — they were
    # dead code, immediately overwritten.
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    # NOTE(review): ``apply_rotary_emb`` is presumably flash-attn's rotary
    # kernel imported elsewhere in this module — confirm against file imports.
    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
    return output
|
||||
|
||||
|
|
|
|||
|
|
@@ -65,8 +65,8 @@ else:
|
|||
|
||||
def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding via the flash-attention kernel.

    Args:
        tensor: Input states (query or key); may be in reduced precision.
        freqs: Rotary angles from which the cos/sin tables are computed.

    Returns:
        ``tensor`` with rotary embedding applied, cast back to its original dtype.
    """
    tensor_ = tensor.float()
    # Bug fix (#36065): compute cos/sin in float32 to match ``tensor_``;
    # the pre-fix assignments without ``.float()`` were dead code here
    # (immediately reassigned) and are dropped.
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    # NOTE(review): ``apply_rotary_emb`` is presumably flash-attn's rotary
    # kernel imported elsewhere in this module — confirm against file imports.
    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
    return output
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue