This commit is contained in:
Arthur Zucker 2024-12-11 18:07:53 +01:00
parent 7a911efddf
commit 20c512bc80

View file

@ -20,17 +20,11 @@ def sdpa_attention_forward(module, query, key, value, attentions_mask=None, **_k
if attentions_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query.device.type == "cuda" and causal_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,