Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
clean
This commit is contained in:
parent 7a911efddf
commit 20c512bc80
1 changed file with 3 additions and 9 deletions
@@ -20,17 +20,11 @@ def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_k
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
-    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-    # Reference: https://github.com/pytorch/pytorch/issues/112577.
-    if query.device.type == "cuda" and causal_mask is not None:
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
-
-    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     is_causal = True if causal_mask is None and query.shape[1] > 1 else False
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
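For context, here is a minimal sketch of how the function body reads after this commit, using only the public torch.nn.functional.scaled_dot_product_attention interface. The wrapper name sdpa_attention_sketch, the signature, and the mask handling around the hunk are illustrative assumptions, not part of this commit; tensor layout conventions are left to the caller.

import torch
import torch.nn.functional as F

def sdpa_attention_sketch(query, key, value, attention_mask=None):
    # Illustrative stand-in for the patched function body (hypothetical name).
    causal_mask = attention_mask
    if attention_mask is not None:
        # Trim the mask to the key sequence length, as in the hunk above.
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # After this commit the inputs are made contiguous unconditionally
    # (previously only on CUDA when a custom mask was present).
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # Explicit boolean, computed outside the SDPA call, mirroring the hunk's
    # check; this keeps kernel selection compatible with torch.compile.
    is_causal = causal_mask is None and query.shape[1] > 1

    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        is_causal=is_causal,
    )

The apparent trade-off of the change: .contiguous() is a no-op on tensors that are already contiguous, so dropping the CUDA-and-mask condition simplifies the code at the cost of a possible extra copy for non-contiguous inputs on other backends.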