Flex Attention HOP: Add support for flex decoding (#129415)
# Flex Decoding
tl;dr: This PR adds a `flex_decoding` kernel to the higher-order op `flex_attention` as the backend for multi-head attention decoding.
The higher-order op `flex_attention` was introduced in https://github.com/pytorch/pytorch/pull/121845 to accept a user-defined score modification callable (`score_mod`) and, through `torch.compile`, create an efficient fused flash attention kernel instantiation. The `flex_attention` kernel is efficient for long-query (>512 tokens) attention. This PR introduces the `flex_decoding` kernel as an alternative backend for the `flex_attention` HOP to handle LLM inference, where short queries (<32 tokens) attend to long key/value sequences.
### Details
LLM decoding iteratively attends each newly generated token (query length = 1) to a long key/value context (up to 132k tokens). The `flex_attention` kernel only parallelizes attention along the query length (M), batch size (B), and number of heads (H) dimensions. LLM decoding lacks enough parallelism in the M dimension to fill all SMs on modern GPUs.
`flex_decoding` adds parallelization along the key/value sequence length (N). The key/value cache of a single head is split into multiple blocks and the query tokens attend to them in parallel. The partial results for the same head are then reduced across KV blocks to produce the global output.
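To make this reduction concrete, below is a minimal PyTorch sketch of the split-KV reduction (no `score_mod`, and `num_splits` is an illustrative parameter; the actual fused kernel performs this combination across KV blocks internally):
```python
import torch

def split_kv_attention(q, k, v, num_splits=4):
    # q: (B, H, M, D); k, v: (B, H, N, D). Each KV split is attended to
    # independently; the partial outputs are then combined using their
    # logsumexp, mirroring the cross-block reduction in flex_decoding.
    outs, lses = [], []
    for k_blk, v_blk in zip(k.chunk(num_splits, dim=2), v.chunk(num_splits, dim=2)):
        scores = q @ k_blk.transpose(-1, -2) / (q.shape[-1] ** 0.5)
        lses.append(scores.logsumexp(dim=-1, keepdim=True))  # (B, H, M, 1)
        outs.append(scores.softmax(dim=-1) @ v_blk)          # (B, H, M, D)
    lse = torch.stack(lses)                                   # (S, B, H, M, 1)
    weights = (lse - lse.logsumexp(dim=0)).exp()              # renormalize each split
    return (weights * torch.stack(outs)).sum(dim=0)
```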
## Examples
Consider a Grouped Query Attention (GQA) decoding case, where a query token with 16 query heads (Hq) attends to 2 kv heads (Hkv). Assume a batch size of 2 (B=2) and a kv cache length of 4096 (N=4096). The attention kernel iteratively attends to the newly generated query token (Mq = 1).
We transform this problem into a Multi-Head Attention (MHA) problem by using a query length equal to the number of query heads per kv head, i.e. M = Hq//Hkv.
The inputs to the `flex_attention` HOP are thus a query of shape (B=2, H=Hkv=2, M=Hq//Hkv=8, D=64) and key/value of shape (B=2, H=Hkv=2, N=4096, D=64), which lead to an intermediate attention score matrix of shape (2, 2, 8, 4096) and an output of shape (2, 2, 8, 64).
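As a minimal sketch of this layout change (tensor names illustrative, not part of the API), the GQA decoding query is simply viewed in the grouped layout before being handed to the HOP; the full example follows below:
```python
import torch

B, Hq, Hkv, D = 2, 16, 2, 64
q_decode = torch.randn(B, Hq, 1, D, device="cuda")  # one newly generated token per query head
# Group the Hq query heads under their Hkv kv heads: (B, Hkv, Hq//Hkv, D)
q_mha = q_decode.view(B, Hkv, Hq // Hkv, D)
```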
```Python
import torch
from torch.nn.attention._flex_attention import _flex_attention as flex_attention
torch.manual_seed(0)
# Let's create some input tensors
# query of shape (B, Hkv, Hq//Hkv, D)
# key/value of shape (B, Hkv, N, D)
query = torch.randn(2, 2, 8, 64, device="cuda", dtype=torch.float32)
key = torch.randn(2, 2, 4096, 64, device="cuda", dtype=torch.float32)
value = torch.randn(2, 2, 4096, 64, device="cuda", dtype=torch.float32)
# Let's create a new score modification, checkerboard.
def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Let's call flex_attention with this new score modification for decoding.
# The flex_attention HOP will choose flex_decoding as its backend since our query length (M) is only 8.
output = flex_attention(query, key, value, score_mod=checkerboard)
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query, key, value, score_mod=checkerboard)
torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```
## Future Plans
- This PR does not implement a load mask for the `score_mod` function. This means that if the `score_mod` function takes a captured buffer along the M dimension, the buffer must be padded to a query length of 16, or to the next power of two of the query length if q_len > 16. For example:
```python
Hq, Hkv = 16, 2  # example head counts
q_scale = torch.randn(Hq // Hkv, device="cuda")
q_scale = torch.nn.functional.pad(q_scale, (0, 16 - Hq // Hkv))  # Pad captured buffer to length 16
def bias_mod(score, batch, head, q, kv):
    score = score + q_scale[q]
    return score
```
- The backward pass for short queries (<128 tokens) currently does not work because the `flex_attention_backward` kernel lacks mask support and only accepts query lengths that are a multiple of 128.
- Dynamic shapes and max-autotune are currently not working.
- Add block-sparse mask support (#129216 is a draft for the flex_attention kernel).
- Add explicit GQA support (#130076 is a draft for GQA support on the flex_attention kernel).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129415
Approved by: https://github.com/Chillee
2024-07-13 00:41:45 +00:00
# Owner(s): ["module: inductor"]
# flake8: noqa: B950
import functools
from collections import namedtuple
from contextlib import nullcontext
[FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for the KV batch dimension.
## Details
Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we require `Bq == Bkv`. However, for some use cases, we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]`, while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention.
This PR relaxes the requirement to `Bq == Bkv or (Bq > 1 and Bkv == 1)`. This support covers flex decoding as well as the flex attention forward and backward passes.
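As a usage sketch (sizes are illustrative and not taken from the benchmark below), a single KV batch can now be broadcast across multiple query batches:
```python
import torch
from torch.nn.attention.flex_attention import flex_attention

Bq, H, Q_LEN, KV_LEN, D = 4, 16, 1024, 4096, 64  # illustrative sizes
query = torch.randn(Bq, H, Q_LEN, D, device="cuda", dtype=torch.bfloat16)
# A single shared KV batch (Bkv == 1) is broadcast across the Bq query batches.
key = torch.randn(1, H, KV_LEN, D, device="cuda", dtype=torch.bfloat16)
value = torch.randn(1, H, KV_LEN, D, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(query, key, value)  # out: (Bq, H, Q_LEN, D)
```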
## Benchmark
GPU: H100
We see a negligible (1–2%) performance change from this PR when `Bq == Bkv`.
```
python benchmarks/transformer/score_mod.py --calculate-bwd
```
### Perf before this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|------------------------------|
| Average | 0.743 | | | | |
| Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.834 | | | | |
| Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) |
| Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 |
</details>
### Perf after this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|----------------------------|
| Average | 0.776 | | | | |
| Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.817 | | | | |
| Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) |
| Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 | 228.384 | 727.776 | 782.368 | 0.593 | 0.930 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 17192.863 | 0.660 | 0.760 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 |
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505
Approved by: https://github.com/Chillee
2024-09-10 09:30:00 +00:00
from typing import Callable, Optional, Tuple
from unittest import expectedFailure, skipUnless
from unittest.mock import patch
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
_identity,
BlockMask,
create_block_mask,
flex_attention,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
|
|
|
|
|
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
|
2024-09-19 18:02:39 +00:00
|
|
|
from torch.testing._internal.common_utils import skipIfRocm
|
from torch.utils._triton import has_triton
# Skip tests if Triton is not available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)

Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index
Tensor = torch.Tensor
Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to the higher-order op `flex_attention`.
## Details
When `enable_gqa` is set to True, the HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Group Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of query heads (`Hq//Hkv` heads) attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi-Head Attention (MHA), where the number of query heads equals the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index.
```python
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
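For instance, a `mask_mod` can use the query-head index in the same way as a `score_mod`. Below is a minimal sketch (illustrative only, not part of this PR), assuming `Hq` and `Hkv` are defined by the caller:
```python
def make_group_causal_mask(Hq: int, Hkv: int):
    # Number of query heads that share one kv head.
    group = Hq // Hkv

    def group_causal(batch, q_head, token_q, token_kv):
        # Illustrative: widen the causal window by the kv-head index this q_head maps to.
        kv_head = q_head // group
        return token_q + kv_head >= token_kv

    return group_causal
```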
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
torch.manual_seed(0)
def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors.
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here the 8 query heads share 2 KV heads (4 query heads per KV head).
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take alibi_bias as an example.
# score_mod takes the batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Let's calculate the kv head from the query head index
        group = num_q_heads // num_kv_heads
        hkv = hq // group
        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale
    return _alibi_bias

# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the compiled kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
2024-08-09 18:09:18 +00:00
def create_attention(score_mod, block_mask, enable_gqa=False):
    return functools.partial(
        flex_attention,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=enable_gqa,
    )
def create_block_mask_test(score_mod, query, key):
    block_mask = create_block_mask(
        score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device
    )
    return block_mask
test_dtypes = (
    [torch.float16, torch.bfloat16, torch.float32]
    if PLATFORM_SUPPORTS_BF16
    else [torch.float16, torch.float32]
)

test_dtypes_fast = [torch.float16]


# --------- Useful score mod functions for testing ---------
def _causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score, float("-inf"))
def _generate_windowed(offset):
    def _windowed(score, b, h, q, kv):
        return torch.where(q + offset >= kv, score, float("-inf"))

    return _windowed
def _get_windowed_sdpa_mask(Mq, Mkv, offset):
    return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
        offset : offset + Mq
    ]
def _rel_bias(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return score + (token_q - token_kv)
def _rel_causal(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))
def _generate_alibi_bias(num_heads: int):
    def _alibi_bias(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        token_q: Tensor,
        token_kv: Tensor,
    ) -> Tensor:
        scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias
def _inverse_causal(score, b, h, m, n):
    return torch.where(m <= n, score, float("-inf"))


def _times_two(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * 2


def _squared(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return score * score


def _head_offset(dtype: torch.dtype):
    """Captured Buffer"""
    head_offset = torch.rand(Hq, device="cuda", dtype=dtype)
    def score_mod(score, b, h, m, n):
        return score * head_offset[h]

    return score_mod


def _trig(score, b, h, m, n):
    """Joint graph needed for correctness"""
    return torch.sin(torch.cos(score)) + torch.tan(b)


def _trig2(score, b, h, m, n):
    """Branching joint graph"""
    cos_score = torch.cos(score)
    sin_score = torch.sin(score)
    z = cos_score * sin_score + torch.tan(b)
    return z


test_score_mods = [
    _identity,
    _times_two,
    _squared,
    _causal,
    _inverse_causal,
    _rel_bias,
    _rel_causal,
    _generate_alibi_bias(8),
    _generate_windowed(1000),
]


captured_buffers_map = {
    "_head_offset": _head_offset,
}


B = 4
S = 2048
D = 64


test_Hq_Hkv = [
    (16, 1),
    (8, 2),
    (16, 16),
]
[FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for KV batch dimension.
## Details
Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we required `Bq == Bkv`. However, for some use cases we may have `Bkv < Bq`. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]` while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximum number of tokens supported by paged attention.
This PR relaxes the requirement to `Bq == Bkv or (Bq > 1 and Bkv == 1)`, i.e. a KV batch dimension of 1 is broadcast against the query batch. This support covers flex decoding as well as the flex attention forward and backward passes.
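As an illustration, here is a minimal sketch of the relaxed batching rule. The shapes and the use of the public `torch.nn.attention.flex_attention` API below are assumptions chosen for the example, not taken from this PR:
```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical shapes: Bq = 4 query sequences share a single KV batch (Bkv = 1),
# e.g. one large paged KV buffer; 8 query heads group onto 2 kv heads.
query = torch.randn(4, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 2, 4096, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 2, 4096, 64, device="cuda", dtype=torch.float16)

# Bkv == 1 is broadcast against Bq == 4; enable_gqa handles the 8 -> 2 head grouping.
out = torch.compile(flex_attention)(query, key, value, enable_gqa=True)
```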
## Benchmark
GPU: H100
We see negligible (1%~2%) performance change from this PR when `Bq == Bkv`.
```
python benchmarks/transformer/score_mod.py --calculate-bwd
```
### Perf before this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|------------------------------|
| Average | 0.743 | | | | |
| Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.834 | | | | |
| Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) |
| Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 |
</details>
### Perf after this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|----------------------------|
| Average | 0.776 | | | | |
| Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.817 | | | | |
| Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) |
| Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
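For context, the eager/compiled timings and speedup columns in the sweep below can in principle be reproduced per shape with a small harness along these lines. This is only a hedged sketch, not the script that produced the table: it assumes the public `torch.nn.attention.flex_attention.flex_attention` entry point with `enable_gqa=True` (see the GQA commit further down in this log), uses `torch.utils.benchmark.Timer` for timing, and picks one (B, Hq, M, Hkv, N, D) row from the table.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.utils.benchmark import Timer

# One (B, Hq, M, Hkv, N, D) configuration from the table below.
B, Hq, M, Hkv, N, D = 2, 16, 1024, 2, 1024, 64
q = torch.randn(B, Hq, M, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Hkv, N, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Hkv, N, D, device="cuda", dtype=torch.bfloat16)

compiled = torch.compile(flex_attention)
compiled(q, k, v, enable_gqa=True)  # warm up: trigger compilation outside the timed region

def bench(fn):
    # Median wall-clock time per call (seconds); Timer handles CUDA synchronization.
    t = Timer("fn(q, k, v, enable_gqa=True)", globals={"fn": fn, "q": q, "k": k, "v": v})
    return t.blocked_autorange().median

fwd_speedup = bench(flex_attention) / bench(compiled)
print(f"fwd_speedup = {fwd_speedup:.3f}")
```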
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 | 228.384 | 727.776 | 782.368 | 0.593 | 0.930 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 17192.863 | 0.660 | 0.760 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 |
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505
Approved by: https://github.com/Chillee
2024-09-10 09:30:00 +00:00
test_Bq_Bkv = [
    (3, 1),
    (5, 1),
    (8, 1),
    (16, 1),
]
(Hq, Hkv) = (16, 8)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

class TestFlexDecoding(InductorTestCase):
    def _check_equal(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
        fudge_factor: float,
        tensor_name: Optional[str] = None,
    ):
        # Mean absolute error of the compiled and eager (reference) outputs
        # against the high-precision golden output.
        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
            self.assertTrue(False, "Output/Grad with NaN")
        if ref_error < (1e-4) * golden_out.abs().mean():
            # The reference error is negligible; fall back to absolute/relative tolerances.
            print(
                "very small ref error of ",
                (ref_error.to(torch.float64) * (1e5) / golden_out.abs().mean()),
            )
            tolerance = Tolerances(atol=2e-1, rtol=2e-1)
            torch.testing.assert_close(
                golden_out.to(dtype=compiled_out.dtype),
                compiled_out,
                atol=tolerance.atol,
                rtol=tolerance.rtol,
            )
        elif compiled_error > ref_error * fudge_factor:
            name = tensor_name if tensor_name is not None else ""
            msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to higher order op `flex_attention`.
## Details
When `enable_gqa` is set to True, the HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Group Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of `Hq//Hkv` query heads attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi-Head Attention (MHA), where the number of query heads equals the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
torch.manual_seed(0)
def query_key_value_clones(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype: torch.dtype = None,
):
"""Clones the query, key, and value tensors and moves them to the specified dtype."""
if dtype is None:
dtype = query.dtype
query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
return query_ref, key_ref, value_ref
# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len
# Here 8 query heads share one KV head.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)
# Lets create a score_modification. We take alibi_bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
def _alibi_bias(
score: torch.Tensor,
b: torch.Tensor,
hq: torch.Tensor,
token_q: torch.Tensor,
token_kv: torch.Tensor,
) -> torch.Tensor:
# Let's calculate kv head from query head index
group = num_q_heads // num_kv_heads
hkv = hq // group
scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
return score + (token_kv - token_q) * scale
return _alibi_bias
# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
return q >= kv
# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)
# Lets call flex_attention with our new score modification and block mask under eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)
# Now lets compile flex_attention and run the flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)
torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
2024-08-09 18:09:18 +00:00
    def _check_out(
        self,
        golden_out: torch.Tensor,
        ref_out: torch.Tensor,
        compiled_out: torch.Tensor,
    ):
        dtype = ref_out.dtype
        with torch.no_grad():
            # Note, it seems like we really are less accurate than the float32
            # computation, likely due to the online softmax
            if dtype == torch.float32:
                fudge_factor = 10.0
            else:
                fudge_factor = 1.1

            # Checkout output
            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

    def run_test(
        self,
        score_mod: Optional[Callable],
        dtype: torch.dtype = torch.float16,
        Q_B: int = B,
        Q_H: int = Hq,
        Q_S: int = 1,
        Q_D: int = D,
        KV_B: int = B,
        KV_H: int = Hkv,
        KV_S: int = S,
        V_D: int = D,
        block_mask: Optional[BlockMask] = None,
    ):
        assert (
            score_mod is not None or block_mask is not None
        ), "Must provide score_mod or block_mask"
|
assert Q_H % KV_H == 0
q = torch.randn(
Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to the higher-order op `flex_attention`.
## Details
When `enable_gqa` is set to True, the HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Group Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of query heads (`Hq//Hkv` heads) attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi-Head Attention (MHA), where the number of query heads equals the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index; a short sketch follows the signatures below.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
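As a quick sketch (the head counts, shapes, and `rel_bias` function below are assumptions for illustration, not taken from this PR), the only change relative to an MHA call is that query and key/value may now carry different numbers of heads, with `enable_gqa=True` passed explicitly:
```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Assumed shapes: 8 query heads share each of 2 kv heads (B=1, seq_len=128, D=64).
q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 2, 128, 64, device="cuda")
v = torch.randn(1, 2, 128, 64, device="cuda")

def rel_bias(score, b, q_head, token_q, token_kv):
    # q_head is the *query* head index (0..7 here), as in the signatures above.
    return score + (token_kv - token_q) * 0.01

out = flex_attention(q, k, v, score_mod=rel_bias, enable_gqa=True)  # shape (1, 8, 128, 64)
```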
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

torch.manual_seed(0)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors.
# The input tensors have shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here 8 query heads share one KV head.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take alibi_bias as an example.
# score_mod takes the batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Derive the kv head from the query head index.
        group = num_q_heads // num_kv_heads
        hkv = hq // group
        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias

# Let's apply a causal mask on top of it.
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the fused kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
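As a sanity check (a sketch under the shapes above, not part of the PR), the GQA result should match an MHA run in which each kv head is replicated `Hq//Hkv` times, since every group of query heads attends to the same kv head:
```python
# Replicate each kv head Hq//Hkv times and run the same call without enable_gqa.
group = query.shape[1] // key.shape[1]  # 8 // 2 = 4 query heads per kv head
key_rep = key.repeat_interleave(group, dim=1)
value_rep = value.repeat_interleave(group, dim=1)
out_mha = flex_attention(query, key_rep, value_rep, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask)
torch.testing.assert_close(output, out_mha, atol=5e-2, rtol=2e-2)
```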
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
2024-08-09 18:09:18 +00:00
(Q_B, Q_H, Q_S, Q_D),
dtype=dtype,
device="cuda",
requires_grad=False,
)
k = torch.randn(
(KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial)
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
compiled_out, compiled_lse = compiled_sdpa(q, k, v, return_lse=True)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
self._check_out(
gold_lse,
ref_lse,
compiled_lse,
)
def run_test_with_call(
self,
sdpa_call: Callable,
golden_call: Optional[Callable] = None,
dtype: torch.dtype = torch.float16,
Q_B: int = B,
Q_H: int = Hq,
Q_S: int = 1,
Q_D: int = D,
KV_B: int = B,
KV_H: int = Hkv,
KV_S: int = S,
V_D: int = D,
):
if not golden_call:
golden_call = sdpa_call
q = torch.randn(
(Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
dtype=dtype,
device="cuda",
requires_grad=False,
)
k = torch.randn(
(KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=False
)
v = torch.randn(
(KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False
)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
compiled_sdpa = torch.compile(sdpa_call)
golden_out = golden_call(q_gold, k_gold, v_gold)
ref_out = golden_call(q_ref, k_ref, v_ref)
compiled_out = compiled_sdpa(q, k, v)
self._check_out(
golden_out,
ref_out,
compiled_out,
)
@supported_platform
@expectedFailure
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_bw_decoding_fails(self, dtype):
make_kv = functools.partial(
torch.randn,
(2, 2, 128, 4),
dtype=dtype,
device="cuda",
requires_grad=True,
)
make_q = functools.partial(
torch.randn,
(2, 2, 8, 4),
dtype=dtype,
device="cuda",
requires_grad=True,
)
q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
block_mask = _create_empty_block_mask(q, k)
|
|
|
|
|
|
|
|
@torch.compile
|
|
|
|
|
def sdpa_hop(q, k, v, score_mod, block_mask):
|
2024-08-10 23:01:14 +00:00
|
|
|
return flex_attention(q, k, v, score_mod)
|
|
|
|
|
|
|
|
|
output = sdpa_hop(q, k, v, _identity, block_mask)
|
|
|
|
|
|
2024-08-10 23:01:14 +00:00
|
|
|
output.backward(backward_grad)
|
|
|
|
|
|
|
|
|
@supported_platform
|
|
|
|
|
@common_utils.parametrize("dtype", test_dtypes)
|
|
|
|
|
@common_utils.parametrize("score_mod", test_score_mods)
|
|
|
|
|
@common_utils.parametrize("head_dims", test_Hq_Hkv)
|
|
|
|
|
def test_builtin_score_mods(
|
|
|
|
|
self, dtype: torch.dtype, score_mod: Callable, head_dims
|
|
|
|
|
):
|
|
|
|
|
Hq, Hkv = head_dims
|
|
|
|
|
assert Hq % Hkv == 0
|
|
|
|
|
self.run_test(score_mod, dtype, Q_H=Hq, KV_H=Hkv)
|
|
|
|
|
|
|
|
|
|
def input_strides_1(B, H, S, D):
|
|
|
|
|
return ((H * S * D, S * D, D, 1), 997) # offset
|
|
|
|
|
|
|
|
|
|
def input_strides_2(B, H, S, D):
|
|
|
|
|
return ((H * D, D, B * H * D, 1), 499) # transposed dimensions
|
|
|
|
|
|
|
|
|
|
def input_strides_3(B, H, S, D):
|
|
|
|
|
return ((S * (D + 1), B * S * (D + 1), (D + 1), 1), 293) # additional buffer
|
|
|
|
|
|
|
|
|
|
def input_strides_4(B, H, S, D):
|
|
|
|
|
return ((1, D, (B + 1) * (H + 1) * D, 1), 97) # shared dimension
|
|
|
|
|
|
|
|
|
|
test_input_strides = [
|
|
|
|
|
input_strides_1,
|
|
|
|
|
input_strides_2,
|
|
|
|
|
input_strides_3,
|
|
|
|
|
input_strides_4,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@supported_platform
|
|
|
|
|
@common_utils.parametrize("dtype", test_dtypes_fast)
|
|
|
|
|
@common_utils.parametrize("k_s", test_input_strides)
|
|
|
|
|
@common_utils.parametrize("v_s", test_input_strides)
|
|
|
|
|
@common_utils.parametrize("head_dims", test_Hq_Hkv)
|
|
|
|
|
def test_strided_inputs(self, dtype: torch.dtype, k_s, v_s, head_dims):
|
|
|
|
|
Hq, Hkv = head_dims
|
|
|
|
|
assert Hq % Hkv == 0
|
|
|
|
|
q1 = torch.randn((B * Hq * D), dtype=dtype, device="cuda")
|
|
|
|
|
k1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda")
|
|
|
|
|
v1 = torch.randn((B * Hkv * S * D * 4), dtype=dtype, device="cuda")
|
|
|
|
|
|
|
|
|
|
k_shape = (B, Hkv, S, D)
|
|
|
|
|
v_shape = (B, Hkv, S, D)
|
|
|
|
|
|
Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to the higher-order op `flex_attention`.
## Details
When `enable_gqa` is set to True, the HOP `flex_attention(query, key, value, score_mod, block_mask, enable_gqa)` runs Group Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of `Hq//Hkv` query heads attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi-Head Attention (MHA), where the number of query heads is equal to the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
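Because the head index passed to `score_mod`/`mask_mod` is the query head index, a mask can also vary per group of query heads. The snippet below is a minimal, hypothetical sketch (not part of this PR) of such a `mask_mod`; the head counts and window sizes are illustrative assumptions.
```python
# Hypothetical sketch: a mask_mod that depends on the query head index (q_head).
# Head counts and window sizes are illustrative assumptions, not values from this PR.
import torch
from torch.nn.attention.flex_attention import create_block_mask

Hq, Hkv = 8, 2                    # 4 query heads share each kv head
group = Hq // Hkv

def sliding_window_per_group(b, q_head, token_q, token_kv):
    # Map the query head to its shared kv head, then pick a causal sliding window per group.
    kv_head = q_head // group
    window = 256 * (kv_head + 1)  # kv head 0 -> 256 tokens, kv head 1 -> 512 tokens
    return (token_q >= token_kv) & (token_q - token_kv < window)

# The mask varies with the head index, so build the block mask with H=Hq instead of H=1.
block_mask = create_block_mask(sliding_window_per_group, B=1, H=Hq, Q_LEN=2048, KV_LEN=2048)
```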
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
torch.manual_seed(0)
def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref
# Let's create some input tensors.
# The input tensors have shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here 8 query heads share one KV head.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)
# Let's create a score modification. We take alibi_bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Let's calculate the kv head index from the query head index
        group = num_q_heads // num_kv_heads
        hkv = hq // group
        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale
    return _alibi_bias

# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
    return q >= kv
# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)
# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)
# Now let's compile flex_attention and run the flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)
torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
2024-08-09 18:09:18 +00:00
|
|
|
q = q1.view(1, Hq, B, D).transpose(0, 2)
|
|
|
|
|
|
|
|
|
k_strides, k_offset = k_s(B, Hkv, S, D)
|
|
|
|
|
k_max = [x * (y - 1) for x, y in zip(k_strides, k_shape)]
|
|
|
|
|
assert sum(k_max) + k_offset < B * Hkv * S * D * 4
|
|
|
|
|
assert k_strides[-1] == 1
|
|
|
|
|
k = torch.as_strided(k1, k_shape, k_strides, k_offset)
|
|
|
|
|
|
|
|
|
|
v_strides, v_offset = v_s(B, Hkv, S, D)
|
|
|
|
|
v_max = [x * (y - 1) for x, y in zip(v_strides, v_shape)]
|
|
|
|
|
assert sum(v_max) + v_offset < B * Hkv * S * D * 4
|
|
|
|
|
assert v_strides[-1] == 1
|
|
|
|
|
v = torch.as_strided(v1, v_shape, v_strides, v_offset)
|
|
|
|
|
|
|
|
|
|
sdpa_partial = create_attention(
|
|
|
|
score_mod=_generate_alibi_bias(8),
|
|
|
|
|
block_mask=None,
|
|
|
|
|
enable_gqa=(not Hq == Hkv),
|
|
|
|
)
|
|
|
|
|
compiled_sdpa = torch.compile(sdpa_partial)
|
|
|
|
|
ref_out = sdpa_partial(q, k, v)
|
|
|
|
|
compiled_out = compiled_sdpa(q, k, v)
|
|
|
|
|
|
|
|
|
|
tolerance = Tolerances(atol=2e-1, rtol=2e-1)
|
|
|
|
|
torch.testing.assert_close(
|
|
|
|
|
ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
|
|
|
|
|
)
|
|
|
|
|
|
[FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for the KV batch dimension.
## Details
Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this PR, we required `Bq == Bkv`. However, for some use cases we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]` while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximum number of tokens supported by paged attention.
This PR relaxes this requirement to `Bq == Bkv or (Bq > 1 and Bkv == 1)`. This support covers flex decoding as well as the flex attention forward and backward kernels.
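To make the relaxed requirement concrete, here is a minimal sketch (not code from this PR; shapes and head counts are illustrative assumptions) in which K and V carry a single broadcast batch entry (`Bkv == 1`) while Q keeps `Bq > 1`, as in the paged-attention use case above.
```python
# Minimal sketch of the Bkv == 1 broadcast case; shapes are illustrative assumptions.
import torch
from torch.nn.attention.flex_attention import flex_attention

Bq, Hq, Q_LEN, D = 4, 8, 128, 64
Hkv, KV_LEN = 2, 2048

query = torch.randn(Bq, Hq, Q_LEN, D, device="cuda", dtype=torch.bfloat16)
# A single KV batch entry, shared (broadcast) across all Bq query batches.
key = torch.randn(1, Hkv, KV_LEN, D, device="cuda", dtype=torch.bfloat16)
value = torch.randn(1, Hkv, KV_LEN, D, device="cuda", dtype=torch.bfloat16)

compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(query, key, value, enable_gqa=True)
assert out.shape == (Bq, Hq, Q_LEN, D)  # output keeps the query batch size
```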
## Benchmark
GPU: H100
We see a negligible (1-2%) performance change from this PR when `Bq == Bkv`.
```
python benchmarks/transformer/score_mod.py --calculate-bwd
```
### Perf before this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|------------------------------|
| Average | 0.743 | | | | |
| Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.834 | | | | |
| Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) |
| Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 |
</details>
### Perf after this PR
**FWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|---------------|------------|----------------|----------------------------|
| Average | 0.776 | | | | |
| Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) |
**BWD**
| Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average | 0.817 | | | | |
| Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) |
| Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
<details>
<summary> Full performance sweep </summary>
| score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 |
| None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 |
| None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 |
| None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 |
| None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 |
| None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 |
| relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 |
| head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 |
| None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 |
| None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 |
| relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 |
| head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 |
| None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 |
| None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 |
| relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 |
| head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 |
| None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 |
| None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 |
| None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 | 228.384 | 727.776 | 782.368 | 0.593 | 0.930 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 17192.863 | 0.660 | 0.760 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 |
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505
Approved by: https://github.com/Chillee
2024-09-10 09:30:00 +00:00

@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("head_dims", test_Hq_Hkv)
@common_utils.parametrize("batch_dims", test_Bq_Bkv)
@common_utils.parametrize("score_mod", test_score_mods)
def test_kv_batch_broadcast(
    self,
    dtype: torch.dtype,
    head_dims: Tuple[int, int],
    batch_dims: Tuple[int, int],
    score_mod: Callable,
):
    Hq, Hkv = head_dims
    assert Hq % Hkv == 0

    Bq, Bkv = batch_dims
    assert Bq > 1 and Bkv == 1

    self.run_test(
        score_mod,
        dtype,
        Bq,
        Hq,
        1,
        D,
        Bkv,
        Hkv,
        S,
        D,
    )
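For readers skimming the test above: it exercises a key/value cache stored with batch size 1 that is shared (broadcast) across a larger query batch. Below is a minimal eager-mode sketch of that shape relationship, with made-up sizes; it illustrates the broadcast only and is not how the fused kernel computes its result.
```python
import torch

Bq_, H_, M_, N_, Dh_ = 4, 2, 8, 128, 64   # illustrative sizes only
q = torch.randn(Bq_, H_, M_, Dh_)          # per-request queries
k = torch.randn(1, H_, N_, Dh_)            # kv cache with batch dim 1
v = torch.randn(1, H_, N_, Dh_)

# Broadcasting the kv batch is equivalent to expanding it explicitly.
k_b = k.expand(Bq_, -1, -1, -1)
v_b = v.expand(Bq_, -1, -1, -1)
weights = (q @ k_b.transpose(-2, -1) / Dh_**0.5).softmax(dim=-1)
out = weights @ v_b
print(out.shape)  # torch.Size([4, 2, 8, 64])
```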
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, dtype: torch.dtype):
    def score_mod(score, b, h, q, kv):
        return torch.where(kv % 2 == 0, score, float("-inf"))

    self.run_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_function_composition(self, dtype: torch.dtype):
    def score_mod_1(score, b, h, m, n):
        return score + (m - n)

    def score_mod_2(score, b, h, m, n):
        return torch.where(m <= n, score, float("-inf"))

    def composed_score_mod(score, b, h, m, n):
        return score_mod_2(score_mod_1(score, b, h, m, n), b, h, m, n)

    self.run_test(composed_score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers(self, dtype: torch.dtype):
    head_offset = torch.rand(Hq, device="cuda", dtype=dtype)

    def score_mod(score, b, h, m, n):
        return score + head_offset[h]

    self.run_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers_all_dims(self, dtype: torch.dtype):
    head_scale = torch.randn(Hq, device="cuda")
    batch_scale = torch.randn(B, device="cuda")
    kv_scale = torch.randn(S, device="cuda")
    q_scale = torch.randn(1, device="cuda")

    def all_bias(score, batch, head, token_q, token_kv):
        score = score + kv_scale[token_kv]
        score = score + q_scale[token_q]
        score = score + head_scale[head]
        score = score + batch_scale[batch]
        return score

    self.run_test(all_bias, dtype)
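Indexing captured 1-D buffers by (batch, head, token_q, token_kv) inside a score_mod, as `all_bias` does above, is the pointwise form of adding broadcast bias terms to the full (B, H, M, N) score tensor. A minimal eager sketch of the same arithmetic, with illustrative sizes:
```python
import torch

B_, H_, M_, N_ = 2, 4, 8, 16               # illustrative sizes only
scores = torch.randn(B_, H_, M_, N_)
batch_scale = torch.randn(B_)
head_scale = torch.randn(H_)
q_scale = torch.randn(M_)
kv_scale = torch.randn(N_)

# Each bias broadcasts along its own dimension only.
biased = (
    scores
    + batch_scale.view(B_, 1, 1, 1)
    + head_scale.view(1, H_, 1, 1)
    + q_scale.view(1, 1, M_, 1)
    + kv_scale.view(1, 1, 1, N_)
)
```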

@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_seq_masking(self, dtype):
    seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
    seq_idx[S // 2 :] = 1

    def seq_mask_mod(score, b, h, q, kv):
        return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

    self.run_test(seq_mask_mod, dtype)
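The `seq_mask_mod` above restricts attention to positions whose `seq_idx` matches, so the two halves of the packed sequence never attend to each other. A toy eager-mode sketch of the resulting mask pattern (sizes illustrative, not part of the test suite):
```python
import torch

S_ = 8
seq_idx = torch.zeros(S_, dtype=torch.bool)
seq_idx[S_ // 2:] = True                   # two packed segments

scores = torch.randn(S_, S_)
same_segment = seq_idx[:, None] == seq_idx[None, :]
masked = torch.where(same_segment, scores, torch.tensor(float("-inf")))
weights = masked.softmax(dim=-1)
print(weights[0, S_ // 2:])                # zero weight across the segment boundary
```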

@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_only(self, dtype):
    bias = torch.randn(1, S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        return score + bias[q, kv]

    self.run_test(bias_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_batch(self, dtype):
    bias = torch.randn(B, 1, S, device="cuda", dtype=dtype)

    def bias_mod(score, b, h, q, kv):
        return score + bias[b, q, kv]

    self.run_test(bias_mod, dtype)

@skipIfRocm
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_head_seq_batch(self, dtype):
    bias = torch.randn(
        B,
        Hq,
        1,
            S,
            device="cuda",
            dtype=dtype,
        )

        def bias_mod(score, b, h, q, kv):
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

    # TODO this config segfaults with Triton without:
    # https://github.com/triton-lang/triton/pull/4540
    @supported_platform
    @common_utils.parametrize("score_mod", test_score_mods)
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
    def test_non_equal_head_dims(self, dtype, score_mod, head_dims):
        qk_d, v_d = head_dims
        context = nullcontext() if qk_d > v_d else self.assertRaises(ValueError)
        with context:
            self.run_test(score_mod, dtype, B, Hq, 1, qk_d, B, Hkv, S, V_D=v_d)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_subgraph_respect_decompostion(self, dtype):
        from torch._decomp import core_aten_decompositions
        from torch.fx.experimental.proxy_tensor import make_fx

        def score_mod_func(score, b, h, q, kv):
            return score - q // (1 + kv)

        make_kv = functools.partial(
            torch.randn,
            (2, 2, 128, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (2, 2, 8, 4),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        query, key, value = make_q(), make_kv(), make_kv()
        # floor_div is not decomposed when the decomposition_table is empty
        attention = functools.partial(flex_attention, score_mod=score_mod_func)
        gm = make_fx(attention, decomposition_table={})(query, key, value)
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
    floor_divide = torch.ops.aten.floor_divide.default(arg3_1, add); arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, floor_divide); arg0_1 = floor_divide = None
    return sub""",
        )

        # floor_div is decomposed for core_aten_decompositions
        gm = make_fx(attention, decomposition_table=core_aten_decompositions())(
            query, key, value
        )
        self.assertExpectedInline(
            gm.sdpa_score0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
    add = torch.ops.aten.add.Tensor(arg4_1, 1); arg4_1 = None
    div = torch.ops.aten.div.Tensor_mode(arg3_1, add, rounding_mode = 'floor'); arg3_1 = add = None
    sub = torch.ops.aten.sub.Tensor(arg0_1, div); arg0_1 = div = None
    return sub""",
        )

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_padded_dense_causal(self, dtype):
        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

        def create_padded_dense_wrapper(orig_score_mod):
            def njt_score_mod(qk, b, h, q, kv):
                return torch.where(
                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
                )

            return njt_score_mod

        causal_njt = create_padded_dense_wrapper(_causal)

        self.run_test(causal_njt, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_scale(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_recompile_changed_score_mod(self, dtype):
        scale = torch.ones((), device="cuda", dtype=torch.int32)
        ADD = True

        def score_mod_scale(qk, b, h, q, kv):
            if ADD:
                return qk + scale
            else:
                return qk * scale

        self.run_test(score_mod_scale, dtype)

        ADD = False
        self.run_test(score_mod_scale, dtype)

    @supported_platform
    @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, dtype):
        scale = torch.randn((B, 8), device="cuda")

        def score_mod_scale(qk, b, h, q, kv):
            return qk + scale[b].sum(dim=-1)

        self.run_test(score_mod_scale, dtype)

    @supported_platform
    def test_multiple_score_mod_calls(self):
        query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(2)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        def f(q, k1, k2, v1, v2):
            q2 = flex_attention(q, k1, v1, score_mod=scoremod_1)
            return flex_attention(q2, k2, v2, score_mod=scoremod_2)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        tolerance = Tolerances(atol=2e-1, rtol=2e-1)
        torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

    @supported_platform
    def test_multiple_score_mod_calls2(self):
        query = torch.randn((1, 8, 4, 64), dtype=torch.float32, device="cuda")
        keys = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]
        values = [
            torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
            for _ in range(3)
        ]

        def scoremod_1(qk, b, h, q, kv):
            return qk + (q - kv)

        def scoremod_2(qk, b, h, q, kv):
            return torch.where(q >= kv, qk, -float("inf"))

        attention1 = functools.partial(flex_attention, score_mod=scoremod_1)

        def f(q, k1, k2, k3, v1, v2, v3):
            q2 = attention1(q, k1, v1)
            q3 = flex_attention(q2, k2, v2, score_mod=scoremod_2)
            return flex_attention(q3, k3, v3, score_mod=scoremod_1)

        out = f(query, *keys, *values)
        out2 = torch.compile(f)(query, *keys, *values)
        self.assertTrue((out - out2).abs().mean() < 1e-2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_njt_causal(self, dtype):
        offsets = torch.tensor(
            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
        )
        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
        for idx in range(len(offsets) - 1):
            seq_idx[offsets[idx] : offsets[idx + 1]] = idx

        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
            def njt_score_mod(qk, b, h, q, kv):
                q_nested = q - offsets[seq_idx[q]]
                kv_nested = kv - offsets[seq_idx[kv]]
                return orig_score_mod(qk, b, h, q_nested, kv_nested)

            return njt_score_mod

        causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)

        self.run_test(causal_njt, dtype)

    @supported_platform
    def test_mixed_dtypes_fails(self):
        query = torch.randn((1, 1, 8, 64), dtype=torch.float32, device="cuda")
        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(
            ValueError, "Expected query, key, and value to have the same dtype"
        ):
            flex_attention(query, key, value, _identity)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune(self):
        def score_mod(score, b, h, m, n):
            return score * 2

        self.run_test(score_mod)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
    def test_max_autotune_with_captured(self):
        head_scale = torch.randn(Hq, device="cuda")
        batch_scale = torch.randn(B, device="cuda")
        tok_scale = torch.randn(S, device="cuda")
        q_scale = torch.randn(1, device="cuda")

        def bias_mod(score, batch, head, token_q, token_kv):
            score = score + tok_scale[token_kv]
            score = score + q_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(bias_mod)

    @supported_platform
    def test_fully_masked_out_rows_0_check_gqa(self):
        # Ensure fully masked out rows won't cause NaNs.
        query = torch.randn(
            (B, Hq, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        key = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        value = torch.randn(
            (B, Hkv, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )

        M = S // 2

        def mask_mod(b, h, q, kv):
            return q < M

        block_mask = create_block_mask(mask_mod, 1, 1, S, S)

        flex = torch.compile(flex_attention, dynamic=False)

        out, lse = flex(
            query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True
        )
        self.assertEqual(out[:, :, M:, :].sum(), 0)
        self.assertTrue((lse[:, :, M:] == -float("inf")).all())

        loss = out.sum() + lse.sum()
        loss.backward()
        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

    @supported_platform
    def test_windowed_no_mask_vs_sdpa(self):
        score_mod = _generate_windowed(1000)
        attention = functools.partial(flex_attention, score_mod=score_mod)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_full_mask_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        score_mod = _generate_windowed(1000)

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(
            flex_attention, block_mask=block_mask, score_mod=score_mod
        )

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

    @supported_platform
    def test_windowed_partial_block_vs_sdpa(self):
        def mask_mod(b, h, q, kv):
            return q + 1000 >= kv

        block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
        attention = functools.partial(flex_attention, block_mask=block_mask)

        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
        sdpa_attention = functools.partial(
            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )

        self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)

        @torch.compile(backend="aot_eager")
        def eager_sdpa_hop(q, k, v, score_mod):
            return flex_attention(q, k, v, score_mod, return_lse=True)
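
        # Note (an observation, not from the original source): the reference hop is
        # compiled with backend="aot_eager", so it traces the HOP but executes the
        # math with eager ops and therefore accepts float64 inputs, while the default
        # torch.compile path above lowers to the fused Inductor kernel, which produces
        # its logsumexp in float32. That asymmetry is what the dtype assertions below
        # check.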

        ref_out, ref_lse = eager_sdpa_hop(
            q.to(torch.float64),
            k.to(torch.float64),
            v.to(torch.float64),
            score_mod,
        )
        compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)

        self.assertTrue(ref_lse.dtype == torch.float64)
        self.assertTrue(compiled_lse.dtype == torch.float32)

        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
        torch.testing.assert_close(
            ref_out.to(dtype=torch.float32),
            compiled_out.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
        torch.testing.assert_close(
            ref_lse.to(dtype=torch.float32),
            compiled_lse.to(dtype=torch.float32),
            atol=tolerance.atol,
            rtol=tolerance.rtol,
        )
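
    # A minimal eager sketch (an assumption, not part of the original test) of the
    # quantity compared above: with return_lse=True the second output should match
    # the log-sum-exp over the KV axis of the scaled, score-modified scores, roughly:
    #
    #     scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (B, H, M, N)
    #     lse_ref = apply_score_mod(scores, score_mod).logsumexp(dim=-1)  # (B, H, M)
    #
    # where apply_score_mod is a hypothetical helper that broadcasts score_mod over
    # (batch, head, q_idx, kv_idx); the kernel's exact scaling convention may differ.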

    @supported_platform
    def test_logsumexp_only_return(self):
        make_q = functools.partial(
            torch.randn,
            (B, Hkv, Hq // Hkv, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        make_kv = functools.partial(
            torch.randn,
            (B, Hkv, S, D),
            dtype=torch.float32,
            device="cuda",
            requires_grad=True,
        )
        q, k, v = make_q(), make_kv(), make_kv()

        @torch.compile
        def func(q, k, v, score_mod):
            _, lse = flex_attention(q, k, v, score_mod, return_lse=True)
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that we're still generating the flexattention kernel
        FileCheck().check_count(".run(primals_1, primals_2, primals_3", 1, True).run(
            code[0]
        )
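
    # The FileCheck pattern above is a smoke test that the fused attention kernel
    # call taking the three primal inputs still appears in the generated code even
    # though only the LSE output is consumed. A hedged sketch of the same idiom on
    # an arbitrary compiled function (the "triton_" prefix is an assumption about
    # how Inductor names its generated kernels):
    #
    #     _, generated = run_and_get_code(torch.compile(fn), *example_inputs)
    #     FileCheck().check("triton_").run(generated[0])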

    @supported_platform
    def test_non_sparse_mulitple_block_size(self):
        def generate_causal_offset(offset: torch.Tensor):
            def causal_offset_mask(b, h, q_idx, kv_idx):
                return (offset + q_idx) >= kv_idx

            return causal_offset_mask

        def noop(score, b, h, q_idx, kv_idx):
            return score

        mod = generate_causal_offset(
            torch.tensor(192, device="cuda", dtype=torch.int32)
        )
        block_mask = create_block_mask(mod, 1, 1, 1, 65)

        self.run_test(
            score_mod=None,
            dtype=torch.float32,
            block_mask=block_mask,
            Q_B=1,
            Q_H=1,
            Q_S=1,
            Q_D=16,
            KV_B=1,
            KV_H=1,
            KV_S=65,
            V_D=16,
        )
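
    # Reading of the test above (an inference, not stated in the original file):
    # KV_S=65 gives a KV length that is not a multiple of the kernel's block size,
    # so the partial last block of the block mask must be handled correctly. A
    # hedged sketch of the same setup outside the test harness:
    #
    #     mask = create_block_mask(causal_offset_mask, 1, 1, 1, 65)
    #     out = flex_attention(q, k, v, block_mask=mask)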

    @supported_platform
    def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self):
        torch._dynamo.reset()
        H = Hq
        q = torch.randn(B, H, 1, D, device="cuda")
        for i in range(5):
            k = torch.randn(B, H, S + i, D, device="cuda")
            v = torch.randn(B, H, S + i, D, device="cuda")
            compiled_flex_attention = torch.compile(flex_attention)
            ref = flex_attention(q, k, v)
            res = compiled_flex_attention(q, k, v)
            tolerance = Tolerances(atol=2e-1, rtol=2e-1)
            torch.testing.assert_close(
                ref, res, atol=tolerance.atol, rtol=tolerance.rtol
            )
            # Ensure no more re-compilation after the second automatic dynamic shape version.
            if i == 0:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
            else:
                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
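
    # The frame-counter check above relies on Dynamo's automatic dynamic shapes:
    # the first call compiles a static-shape graph, the changed KV length on the
    # second call triggers exactly one recompile with a dynamic dimension, and the
    # remaining iterations reuse it, so counters["frames"]["ok"] goes 1 -> 2 and
    # then stays at 2. A hedged standalone sketch of the same behaviour:
    #
    #     torch._dynamo.reset()
    #     f = torch.compile(lambda x: x.sin())
    #     for n in (8, 9, 10, 11):
    #         f(torch.randn(n))
    #     # expect at most two compiled frames despite four distinct input shapes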

common_utils.instantiate_parametrized_tests(TestFlexDecoding)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()