mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Only cast cu_seqlens when tracing (#35016)
* Only cast `cu_seqlens` when tracing * Formatting
This commit is contained in:
parent
19dabe9636
commit
3480cbb97e
1 changed files with 6 additions and 1 deletions
|
|
@ -1025,7 +1025,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
|||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
dim=0, dtype=grid_thw.dtype
|
||||
dim=0,
|
||||
# Select dtype based on the following factors:
|
||||
# - FA2 requires that cu_seqlens_q must have dtype int32
|
||||
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
|
||||
# See https://github.com/huggingface/transformers/pull/34852 for more information
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue