From 44a26c871c89bba3cb959f91dba625e2b19e2979 Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Mon, 6 Jan 2025 08:54:31 -0800
Subject: [PATCH] Update llm_optims docs for `sdpa_kernel` (#35481)

update: use sdpa_kernel
---
 docs/source/en/llm_optims.md | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 17ebb841d..37406ea0b 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -156,9 +156,11 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
 There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
 1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
 2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
-3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
+3. Use `SDPBackend.MATH` in the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
 
 ```py
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
@@ -179,7 +181,7 @@ with torch.no_grad():
     decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
     cache_position = torch.tensor([seq_length + 1], device=torch_device)
     for _ in range(1, NUM_TOKENS_TO_GENERATE):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+        with sdpa_kernel(SDPBackend.MATH):
             next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
         generated_ids[:, cache_position] = next_token.int()
         cache_position += 1
@@ -453,10 +455,11 @@ Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and
 > [!TIP]
 > SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
 
-Use the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to explicitly enable or disable any of the three attention algorithms. For example, set `enable_flash=True` to enable FlashAttention.
+Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention.
 
 ```py
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
@@ -464,7 +467,7 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
 )
 
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
     outputs = model.generate(**inputs)
 ```
 
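
For context, a minimal self-contained sketch of the `torch.nn.attention.sdpa_kernel` usage the updated docs describe. This is illustrative rather than part of the patch: the checkpoint, prompt, and generation length are placeholder assumptions, and the new context manager requires PyTorch 2.3 or later.

```py
# A sketch, not part of the patch: how the new torch.nn.attention API replaces
# the deprecated torch.backends.cuda.sdp_kernel flags. Assumes PyTorch >= 2.3
# and a CUDA device; the checkpoint and prompt below are placeholders.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to("cuda")
inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to("cuda")

# Deprecated form replaced by this patch:
#   with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# New form: pass the backend to allow, or a list of allowed backends.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs, max_new_tokens=20)

# Several backends can be allowed at once, letting PyTorch pick among them:
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Unlike the boolean `enable_*` flags, `sdpa_kernel` takes the allowed backends directly, which is why the patch can express "math only" as `SDPBackend.MATH` and "flash only" as `SDPBackend.FLASH_ATTENTION`. The fourth algorithm the updated prose counts, alongside flash, memory-efficient, and math, is `SDPBackend.CUDNN_ATTENTION`.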