Only cast cu_seqlens when tracing (#35016)

* Only cast `cu_seqlens` when tracing

* Formatting
This commit is contained in:
Joshua Lochner 2024-12-02 12:39:39 +02:00 committed by GitHub
parent 19dabe9636
commit 3480cbb97e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1025,7 +1025,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=grid_thw.dtype
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)