Update llm_optims docs for sdpa_kernel (#35481)

update: use sdpa_kernel
This commit is contained in:
Jacky Lee 2025-01-06 08:54:31 -08:00 committed by GitHub
parent 18e896bd8f
commit 44a26c871c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -156,9 +156,11 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
3. Use `SDPBackend.MATH` in the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
```py
from torch.nn.attention import SDPBackend, sdpa_kernel
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
@ -179,7 +181,7 @@ with torch.no_grad():
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
with sdpa_kernel(SDPBackend.MATH):
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1
@ -453,10 +455,11 @@ Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and
> [!TIP]
> SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
Use the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to explicitly enable or disable any of the three attention algorithms. For example, set `enable_flash=True` to enable FlashAttention.
Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention.
```py
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
@ -464,7 +467,7 @@ model = AutoModelForCausalLM.from_pretrained(
torch_dtype=torch.bfloat16,
)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
outputs = model.generate(**inputs)
```