Update llm_optims docs for sdpa_kernel (#35481)
update: use sdpa_kernel
This commit is contained in:
parent 18e896bd8f
commit 44a26c871c
1 changed file with 7 additions and 4 deletions
@@ -156,9 +156,11 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
 There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
 1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
 2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
-3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
+3. Use `SDPBackend.MATH` in the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
 
 ```py
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
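The first hunk cuts off inside the `StaticCache` constructor. As a minimal, self-contained sketch of step 1, the cache is pre-allocated once before generation; the checkpoint name, prompt, and cache length below are illustrative assumptions, not values taken from this diff:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Assumed checkpoint and prompt, chosen only for illustration.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer("My favorite theorem is", return_tensors="pt")

batch_size, seq_length = inputs["input_ids"].shape
# Pre-allocate the kv-cache with a fixed batch size and maximum length so
# torch.compile sees the same tensor shapes on every decoding step.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=seq_length + 128,  # assumed budget for generated tokens
    device=model.device,
    dtype=model.dtype,
)
```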
@@ -179,7 +181,7 @@ with torch.no_grad():
 decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
 cache_position = torch.tensor([seq_length + 1], device=torch_device)
 for _ in range(1, NUM_TOKENS_TO_GENERATE):
-    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+    with sdpa_kernel(SDPBackend.MATH):
         next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
     generated_ids[:, cache_position] = next_token.int()
     cache_position += 1
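The change in this hunk is a pure API migration: the deprecated `torch.backends.cuda.sdp_kernel` took boolean flags (`enable_flash`, `enable_math`, `enable_mem_efficient`), while `torch.nn.attention.sdpa_kernel` takes `SDPBackend` members instead. A minimal, CPU-runnable sketch of the new context manager; the toy tensors are assumptions for illustration:

```py
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors shaped (batch, heads, seq_len, head_dim), for illustration only.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Restrict SDPA to the math backend inside the context; outside it,
# PyTorch falls back to its usual automatic backend selection.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```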
@@ -453,10 +455,11 @@ Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and
 > [!TIP]
 > SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
 
-Use the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to explicitly enable or disable any of the three attention algorithms. For example, set `enable_flash=True` to enable FlashAttention.
+Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention.
 
 ```py
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
@@ -464,7 +467,7 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
 )
 
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     outputs = model.generate(**inputs)
 ```
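For reference, `sdpa_kernel` also accepts a list of backends, which matters here because FlashAttention only runs on CUDA devices with fp16/bf16 inputs. A hedged sketch of that pattern; the device, dtype, and shapes are assumptions:

```py
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# FlashAttention requires a CUDA device and half-precision tensors (assumed here).
device, dtype = "cuda", torch.bfloat16
q = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)

# Allow FlashAttention, and fall back to the memory-efficient backend when
# an input doesn't satisfy FlashAttention's constraints.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```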