Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to the higher-order op `flex_attention`.
## Details
When `enable_gqa` is set to True, the HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Grouped Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of `Hq//Hkv` query heads attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi Head Attention (MHA), where the number of query heads equals the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index, i.e. the index of the query head:
```python
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
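For intuition, the query-head to kv-head mapping implied by this API can be sketched as follows (a minimal illustration with hypothetical head counts, not part of the API itself):

```python
Hq, Hkv = 8, 2            # query heads, key/value heads
group = Hq // Hkv         # 4 query heads share each kv head
for q_head in range(Hq):
    kv_head = q_head // group
    print(f"query head {q_head} -> kv head {kv_head}")
```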
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

torch.manual_seed(0)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors.
# The input tensors have shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here every group of 4 query heads shares one KV head (Hq=8, Hkv=2).
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take an alibi bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Derive the kv head from the query head index.
        group = num_q_heads // num_kv_heads
        hkv = hq // group
        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias

# Let's apply a causal mask on top of it.
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the compiled flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
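For reference, the eager decomposition added in this PR broadcasts the key/value heads with `torch.repeat_interleave`, so when no `score_mod`/`block_mask` is used, the GQA output can be sanity-checked against plain SDPA on broadcast K/V. A minimal sketch (separate from the alibi example above; tensor sizes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

B, Hq, Hkv, S, D = 2, 8, 2, 128, 64
q = torch.randn(B, Hq, S, D, device="cuda")
k = torch.randn(B, Hkv, S, D, device="cuda")
v = torch.randn(B, Hkv, S, D, device="cuda")

# Broadcast each kv head to its group of query heads, as the eager decomposition does.
G = Hq // Hkv
k_b = torch.repeat_interleave(k, G, dim=1)   # [B, Hq, S, D]
v_b = torch.repeat_interleave(v, G, dim=1)

out_gqa = flex_attention(q, k, v, enable_gqa=True)
out_ref = F.scaled_dot_product_attention(q, k_b, v_b)
torch.testing.assert_close(out_gqa, out_ref, atol=2e-2, rtol=2e-2)
```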
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
This commit is contained in:
parent dc8bb2636c
commit 4110cb6ba7

7 changed files with 461 additions and 362 deletions
The first changed file is the flex_attention benchmark script. The diff viewer output here is only partially recoverable (the rendering lost the +/- markers and gutter); the visible changes are:

- `create_block_mask`, `create_mask`, `_create_empty_block_mask`, and `flex_attention` are imported from `torch.nn.attention.flex_attention`.
- `generate_inputs` builds the query as `(batch_size, q_sequence_length, q_heads, head_dim)` directly, instead of folding the head groups into the query sequence length (`num_h_groups * q_sequence_length` with `kv_heads`).
- `run_single_experiment`: `eager_sdpa` now accepts an `attn_mask` (created with `create_mask` when the mask_mod isn't causal) and reshapes its output to `(batch_size, q_heads, q_seq_len, head_dim)`; the block mask is created with the true `q_seq_len` rather than `q_seq_len * (q_heads // kv_heads)`; key/value are broadcast for the eager baseline with `torch.repeat_interleave(key, q_heads // kv_heads, dim=1)`; the compiled path is called as `compiled_sdpa(query, key, value, score_mod=score_mod, block_mask=block_mask, enable_gqa=True)`; eager and compiled outputs are compared with `torch.testing.assert_close` when `score_mod is None`; the backward timing uses the same calling convention.
- The local GQA workaround (`get_gqa_score_mod` / `get_gqa_mask_mod`, which remapped the head as `hq = hkv * G + m // q_seq_len` and the query index as `m % q_seq_len`) is deleted from `generate_experiment_configs`; only the `q_heads % kv_heads == 0` assertion remains, and `score_mods` are now zipped with `mask_mods` instead of taking their product.
- Minor updates to `main` and the `--cal-bandwidth` / `--throughput` argument (kernel memory bandwidth & computational throughput).
The next changed file is the flex_attention test suite (`TestFlexAttention`). Recoverable changes:

- `create_attention` gains an `enable_gqa=False` parameter that is forwarded to `flex_attention`; `create_block_mask_test` is only reformatted.
- `run_test` builds its attention partial with `enable_gqa=(not Q_H == KV_H)`, so existing tests exercise GQA whenever the query and kv head counts differ.
- A new `test_GQA`, parametrized over `test_dtypes_fast` and `test_score_mods`, runs with `Hq = 4 * Hkv` query heads and a shorter query length (`S // 8`) against full-length key/value.
- A new `test_GQA_causal_mask` builds a causal block mask of size `S // 8` and runs `flex_attention` with `block_mask=block_mask, enable_gqa=True` and `Hq = 4 * Hkv` through `run_test_with_call`.
The next changed file is the flex decoding test suite (`TestFlexDecoding`). Recoverable changes:

- The local GQA workaround (`get_gqa_score_mod` / `get_gqa_mask_mod`) is deleted; `create_attention` instead gains the same `enable_gqa` parameter as above.
- `run_test` allocates the query as `(Q_B, Q_H, Q_S, Q_D)` rather than `(Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D)` and builds its attention partial with `enable_gqa=(not Q_H == KV_H)`.
- `_check_out_and_grad` is renamed to `_check_out`; the gradient checks and the corresponding backward calls in `run_test` are commented out with a "TODO: add backward support" note.
- One entry is dropped from `test_Hq_Hkv`, and the strided-input test now builds `q` as `q1.view(1, Hq, B, D).transpose(0, 2)` and creates its attention partial with `enable_gqa=(not Hq == Hkv)`.
The next changed file is the eager/decomposition implementation of the HOP (`math_attention` and `sdpa_dense_backward`). Recoverable changes:

- `math_attention` broadcasts key and value along the head dimension for GQA before computing scores: `G = query.size(1) // key.size(1)`, then `key = torch.repeat_interleave(key, G, dim=1)` and `value = torch.repeat_interleave(value, G, dim=1)`. (The added comment reads "broadcast query & key along head dim for GQA", but it is key and value that are broadcast.)
- `sdpa_dense_backward` performs the same broadcast and, after computing the gradients on the broadcast layout, reduces DK and DV back to the kv-head layout: `grad_key` and `grad_value` are viewed as `(B, Hkv, G, seq_len, head_dim)` and summed over the group dimension with `torch.sum(..., 2, keepdim=False)` before being returned contiguous.
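To illustrate the backward reduction described above, here is a minimal, self-contained sketch of how gradients computed on the broadcast `[B, Hq, ...]` layout are folded back to the `[B, Hkv, ...]` key/value layout (shapes are illustrative, not the kernel code):

```python
import torch

B, Hkv, G, N, D = 2, 2, 4, 16, 8                  # G = Hq // Hkv query heads per kv head
grad_key_broadcast = torch.randn(B, Hkv * G, N, D)

# Sum over each group of shared query heads to recover the kv-head layout.
grad_key = grad_key_broadcast.view(B, Hkv, G, N, D).sum(dim=2)   # -> [B, Hkv, N, D]
print(grad_key.shape)
```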
The next changed file is the Inductor lowering and Triton templates for `flex_attention` (forward and backward). Recoverable changes:

- `flex_attention_grid` takes `q_heads`; the grid remains `(cdiv(num_queries, BLOCK_M), batch_size * q_heads, 1)`.
- Forward template: `HQ = size(Q, 1)` replaces `H`; each program computes `off_hq = tl.program_id(1) % HQ`, `off_hkv = off_hq // GQA_SHARED_HEADS`, and `off_g = off_hq % GQA_SHARED_HEADS`, offsetting Q by `off_hq` but K and V by `off_hkv`. Block-sparse metadata is indexed through `SPARSE_HQ` / `sparse_idx_hq`, and the output is stored at `idx_hq`. `forward_inner` now receives `off_hq` and 2-D `offs_m[:, None]` / `offs_n[None, :]`, and the score/mask modification subgraphs are emitted with `m="offs_m"`, `n="offs_n"`, `h="off_hq"`.
- The forward and backward lowerings both set `kernel_options["GQA_SHARED_HEADS"] = Hq // Hkv`, and the logsumexp shape comment becomes `[B, Hq, Mq]`.
- `flex_attention_backward_grid` takes both `q_heads` and `kv_heads`: the grid is `(cdiv(num_queries, BLOCK_M2) * (q_heads // kv_heads) + cdiv(num_key_value, BLOCK_N1), 1, batch_size * kv_heads)`, i.e. each program handles one kv head.
- Backward template: K, V, and DV are offset by `off_hkv`. The dQ branch derives its query head as `off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS` and offsets Q/DO/DQ/LSE/DELTA per query head; the dK/dV branch loops `for off_g in range(0, GQA_SHARED_HEADS)` over the query heads sharing the kv head, accumulating `dk`/`dv` across the group before storing with `off_hkv`. `bwd_dq_inner` / `bwd_dkdv_inner` now take the pointers first and `off_hq` instead of `off_h`.
- The backward template call passes `call_sizes=query.get_size() + key.get_size()[1:3]` so the kernel also receives the kv head count and kv sequence length.
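As a quick reference for the head-indexing scheme above, the per-program offset computation can be written out in plain Python (a hypothetical restatement of the Triton index arithmetic, not the kernel itself):

```python
def head_offsets(program_id_1: int, HQ: int, GQA_SHARED_HEADS: int):
    # Map the second grid dimension to (batch, query head, kv head, group position).
    off_z = program_id_1 // HQ             # batch index
    off_hq = program_id_1 % HQ             # query head handled by this program
    off_hkv = off_hq // GQA_SHARED_HEADS   # kv head shared by this query head
    off_g = off_hq % GQA_SHARED_HEADS      # position of the query head within its group
    return off_z, off_hq, off_hkv, off_g

# e.g. with 8 query heads and 4 shared heads per kv head:
print(head_offsets(13, HQ=8, GQA_SHARED_HEADS=4))  # (1, 5, 1, 1)
```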
|
|||
|
|
@ -9,7 +9,7 @@ from torch._inductor.virtualized import V
|
|||
|
||||
from ..ir import FixedLayout, FlexibleLayout
|
||||
from ..lowering import empty, empty_strided, lowerings
|
||||
from ..runtime.runtime_utils import next_power_of_2
|
||||
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
|
||||
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
|
||||
from .flex_attention import compute_forward_block, compute_next_offset_func
|
||||
|
||||
|
|
@ -18,15 +18,15 @@ aten = torch.ops.aten
|
|||
prims = torch.ops.prims
|
||||
|
||||
|
||||
def flex_decoding_grid(batch_size, num_heads, n_keys, d_model, meta):
|
||||
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
|
||||
"""How is this kernel parallelized?
|
||||
We create a grid of (batch_size * num_heads, SPLIT_KV, 1)
|
||||
We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)
|
||||
Each block is responsible for iterating over blocks of keys and values calculating
|
||||
the local output for their tile of keys and values over all full length of query.
|
||||
groups of SPLIT_KV blocks then combine their output to produce the final result.
|
||||
"""
|
||||
|
||||
return (batch_size * num_heads, meta["SPLIT_KV"], 1)
|
||||
return (batch_size * kv_heads, meta["SPLIT_KV"], 1)
|
||||
|
||||
|
||||
flex_decoding_template = TritonTemplate(
|
||||
|
|
@ -45,6 +45,7 @@ flex_decoding_template = TritonTemplate(
|
|||
# TILE_KV: length of each local KV split
|
||||
# BLOCK_M: block size that Q is padded along seqlen dim.
|
||||
# BLOCK_N: block size of K & V along N dimension.
|
||||
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
||||
#
|
||||
# change of base out of the loop
|
||||
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
||||
|
|
@ -61,70 +62,73 @@ flex_decoding_template = TritonTemplate(
|
|||
#
|
||||
# Output: ACC output accumulated across local KV split.
|
||||
|
||||
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
||||
|
||||
# Define Q Strides
|
||||
stride_qz = {{stride("Q", 0)}}
|
||||
stride_qh = {{stride("Q", 1)}}
|
||||
stride_qm = {{stride("Q", 2)}}
|
||||
stride_qk = {{stride("Q", 3)}}
|
||||
# Define K Strides
|
||||
stride_kz = {{stride("K", 0)}}
|
||||
stride_kh = {{stride("K", 1)}}
|
||||
stride_kn = {{stride("K", 2)}}
|
||||
stride_kk = {{stride("K", 3)}}
|
||||
# Define V Strides
|
||||
stride_vz = {{stride("V", 0)}}
|
||||
stride_vh = {{stride("V", 1)}}
|
||||
stride_vk = {{stride("V", 2)}}
|
||||
stride_vn = {{stride("V", 3)}}
|
||||
# Define M Strides
|
||||
stride_mz = {{stride("M", 0)}}
|
||||
stride_mh = {{stride("M", 1)}}
|
||||
stride_mt = {{stride("M", 2)}}
|
||||
stride_mm = {{stride("M", 3)}}
|
||||
# Define L Strides
|
||||
stride_lz = {{stride("L", 0)}}
|
||||
stride_lh = {{stride("L", 1)}}
|
||||
stride_lt = {{stride("L", 2)}}
|
||||
stride_lm = {{stride("L", 3)}}
|
||||
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
|
||||
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
||||
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
||||
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
|
||||
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
|
||||
|
||||
|
||||
Z = {{size("Q", 0)}}
|
||||
H = {{size("Q", 1)}}
|
||||
Q_LEN = {{size("Q", 2)}}
|
||||
HKV = {{size("Q", 1)}}
|
||||
G: tl.constexpr = GQA_SHARED_HEADS
|
||||
HQ = HKV * G
|
||||
Q_LEN = {{size("Q", 3)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
# Make sure each split is a multiple of BLOCK_N
|
||||
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
||||
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
||||
|
||||
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
||||
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_H = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
MATMUL_PRECISION = Q.dtype.element_ty
|
||||
|
||||
off_z = tl.program_id(0) // H
|
||||
off_h = tl.program_id(0) % H
|
||||
# Make sure each split is a multiple of BLOCK_N
|
||||
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
||||
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
||||
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
||||
|
||||
off_z = tl.program_id(0) // HKV
|
||||
off_hkv = tl.program_id(0) % HKV
|
||||
off_t = tl.program_id(1)
|
||||
|
||||
sparse_idx_z = off_z % SPARSE_Z
|
||||
sparse_idx_h = off_h % SPARSE_H
|
||||
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
||||
k_offset = off_z * stride_kz + off_hkv * stride_kh
|
||||
v_offset = off_z * stride_vz + off_hkv * stride_vh
|
||||
|
||||
# TODO: strided KV_IDX and KV_NUM_BLKS
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
sparse_idx_z = off_z % SPARSE_Z
|
||||
# TODO: support masks not broadcasted along the head dimension.
|
||||
tl.device_assert(SPARSE_HQ == 1)
|
||||
sparse_idx_h = 0
|
||||
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
# initialize offsets
|
||||
tl.device_assert(BLOCK_M % G == 0)
|
||||
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
||||
off_g = tl.arange(0, G) # [G]
|
||||
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_hq = offs_g + off_hkv * G
|
||||
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
||||
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
# KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h
|
||||
|
||||
# Calculate KV blocks that belong this CTA.
|
||||
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
||||
|
||||
q_offset = off_z * stride_qz + off_h * stride_qh
|
||||
k_offset = off_z * stride_kz + off_h * stride_kh
|
||||
v_offset = off_z * stride_vz + off_h * stride_vh
|
||||
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(Q_LEN, BLOCK_DMODEL), # (M, d)
|
||||
|
|
@ -134,18 +138,13 @@ flex_decoding_template = TritonTemplate(
|
|||
order=(1, 0)
|
||||
)
|
||||
if SAFE_M_BOUNDARY:
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = tl.load(Q + q_offset + q_range)
|
||||
else:
|
||||
q = tl.load(Q_block_ptr, boundary_check=(0, ))
|
||||
mask = off_m[None, :, None] < Q_LEN
|
||||
q = tl.load(Q + q_offset + q_range, mask)
|
||||
|
||||
# initialize offsets
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
q = tl.reshape(q, [BLOCK_M, BLOCK_DMODEL])
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Apply both score_mod and mask_mod
|
||||
|
|
@ -171,7 +170,7 @@ flex_decoding_template = TritonTemplate(
|
|||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(KV_LEN, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(off_n, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
|
|
@ -183,7 +182,7 @@ flex_decoding_template = TritonTemplate(
|
|||
# accumulatd values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
|
|
@ -216,7 +215,7 @@ flex_decoding_template = TritonTemplate(
|
|||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(KV_LEN, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(off_n, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
|
|
@ -228,7 +227,7 @@ flex_decoding_template = TritonTemplate(
|
|||
# accumulatd values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
|
|
@ -237,42 +236,46 @@ flex_decoding_template = TritonTemplate(
|
|||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
m_offset = off_h * stride_mh + off_z * stride_mz
|
||||
l_offset = off_h * stride_lh + off_z * stride_lz
|
||||
m_offset = off_t * stride_mt + off_z * stride_mz
|
||||
l_offset = off_t * stride_lt + off_z * stride_lz
|
||||
|
||||
M_block_ptr = tl.make_block_ptr(
|
||||
base=M + m_offset,
|
||||
shape=(SPLIT_KV, Q_LEN), # (T, M)
|
||||
strides=(stride_mt, stride_mm),
|
||||
offsets=(off_t, 0),
|
||||
block_shape=(1, BLOCK_M),
|
||||
shape=(G, Q_LEN), # (G, M)
|
||||
strides=(stride_mh, stride_mm),
|
||||
offsets=(off_hkv*G, 0),
|
||||
block_shape=(G, BLOCK_M_PER_HQ),
|
||||
order=(1, 0)
|
||||
)
|
||||
L_block_ptr = tl.make_block_ptr(
|
||||
base=L + l_offset,
|
||||
shape=(SPLIT_KV, Q_LEN), # (T, M)
|
||||
strides=(stride_lt, stride_lm),
|
||||
offsets=(off_t, 0),
|
||||
block_shape=(1, BLOCK_M),
|
||||
shape=(G, Q_LEN), # (G, M)
|
||||
strides=(stride_lh, stride_lm),
|
||||
offsets=(off_hkv*G, 0),
|
||||
block_shape=(G, BLOCK_M_PER_HQ),
|
||||
order=(1, 0)
|
||||
)
|
||||
|
||||
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
||||
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
||||
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
||||
if SAFE_M_BOUNDARY:
|
||||
tl.store(M_block_ptr, m_i[None, :])
|
||||
tl.store(L_block_ptr, l_i[None, :])
|
||||
tl.store(M_block_ptr, m_i)
|
||||
tl.store(L_block_ptr, l_i)
|
||||
else:
|
||||
tl.store(M_block_ptr, m_i[None, :], boundary_check=(1,))
|
||||
tl.store(L_block_ptr, l_i[None, :], boundary_check=(1,))
|
||||
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
||||
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
||||
|
||||
# -- store output
|
||||
idx_z = off_z
|
||||
idx_h = off_h
|
||||
idx_t = off_t
|
||||
idx_m = offs_m[:, None]
|
||||
idx_d = offs_d[None, :]
|
||||
idx_hq = off_hkv*G + off_g[:, None, None]
|
||||
idx_m = off_m[None, :, None]
|
||||
idx_d = offs_d[None, None, :]
|
||||
# TODO generalize and add proper mask support
|
||||
mask = (idx_m < Q_LEN)
|
||||
{{store_output(("idx_z", "idx_h", "idx_t", "idx_m", "idx_d"), "acc", "mask")}}
|
||||
acc = acc.reshape(G, BLOCK_M_PER_HQ, BLOCK_DMODEL)
|
||||
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
|
||||
"""
|
||||
+ compute_forward_block
|
||||
+ compute_next_offset_func,
|
||||
|
|
@@ -325,6 +328,15 @@ def create_flex_decoding_kernel(*args, **kwargs):
     ) = block_mask

+    kernel_options = dict(kernel_options)
+
+    # Calculate GQA head sharing
+    gqa_shared_heads = query.get_size()[1] // key.get_size()[1]
+    if not is_power_of_2(gqa_shared_heads):
+        raise ValueError(
+            "Number of shared query heads sharing the same KV head must be power of 2. "
+        )
+    kernel_options["GQA_SHARED_HEADS"] = gqa_shared_heads
+
     # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
     kernel_options["HAS_FULL_BLOCKS"] = full_kv_num_blocks is not None
     if (
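
The power-of-two restriction on the group size can be checked with the usual bit trick; the standalone helper below is only a hedged sketch, not the inductor's own `is_power_of_2`:

```python
def is_power_of_2(n: int) -> bool:
    # n is a power of two iff it is positive and has exactly one bit set.
    return n > 0 and (n & (n - 1)) == 0

Hq, Hkv = 8, 2
gqa_shared_heads = Hq // Hkv   # 4 query heads share each kv head
assert is_power_of_2(gqa_shared_heads)
assert not is_power_of_2(6)    # e.g. Hq=12, Hkv=2 would be rejected
```
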
@@ -363,10 +375,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
     assert kernel_options["SPLIT_KV"] <= MAX_SPLIT_KV

     # create config dependent intermediate buffers
-    buf_ML_shape = query.get_size()[:-2] + [
-        MAX_SPLIT_KV,
-        query.get_size()[-2],
-    ] # [B, H, SPLIT_KV, M]
+    buf_ACC_shape = (
+        query.get_size()[:1] + [MAX_SPLIT_KV] + query.get_size()[1:]
+    ) # [B, SPLIT_KV, Hq, M, D]
+    buf_ML_shape = buf_ACC_shape[:-1]
     buf_M = empty_strided(
         buf_ML_shape,
         None,
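
To make the buffer-layout change concrete, here is a hedged shape walkthrough with example sizes (the real sizes come from `query.get_size()` and the chosen split count):

```python
# Example sizes only.
B, Hq, M, D = 2, 8, 16, 64
MAX_SPLIT_KV = 4

# Old layout: the split dimension was nested under each query head.
old_ML_shape = [B, Hq, MAX_SPLIT_KV, M]            # [B, H, SPLIT_KV, M]
old_ACC_shape = [B, Hq, MAX_SPLIT_KV, M, D]        # [B, H, SPLIT_KV, M, D]

# New layout: the split dimension comes first, so each split holds all Hq heads
# and the G heads that share a kv head sit next to each other.
new_ACC_shape = [B] + [MAX_SPLIT_KV] + [Hq, M, D]  # [B, SPLIT_KV, Hq, M, D]
new_ML_shape = new_ACC_shape[:-1]                  # [B, SPLIT_KV, Hq, M]
print(old_ML_shape, new_ML_shape, new_ACC_shape)
```
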
@@ -380,10 +392,6 @@ def create_flex_decoding_kernel(*args, **kwargs):
         device=query.get_device(),
     )

-    buf_ACC_shape = (
-        query.get_size()[:-2] + [MAX_SPLIT_KV] + query.get_size()[-2:]
-    ) # [B, H, SPLIT_KV, M, D]
-
     layout_acc = FixedLayout(
         query.get_device(),
         torch.float32,
@@ -403,14 +411,31 @@ def create_flex_decoding_kernel(*args, **kwargs):
                 V.graph.sizevars.size_hint(
                     m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
                 )
+                * gqa_shared_heads
             ),
             16,
         )
     )
-    V.graph.sizevars.guard_leq(m, sympy.Integer(kernel_options["BLOCK_M"]))
+
+    # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
+    gqa_query_shape = (
+        query.get_size()[:1]
+        + [key.get_size()[1], gqa_shared_heads]
+        + query.get_size()[2:]
+    )
+    gqa_query_stride = (
+        query.get_stride()[:1]
+        + [query.get_stride()[1] * gqa_shared_heads, query.get_stride()[1]]
+        + query.get_stride()[2:]
+    )
+    query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
+
+    V.graph.sizevars.guard_leq(
+        m * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
+    )

     kernel_options["SAFE_M_BOUNDARY"] = (
-        query.get_size()[-2] % kernel_options["BLOCK_M"]
+        (m * gqa_shared_heads) % kernel_options["BLOCK_M"]
     ) == 0
     kernel_options["SAFE_N_BOUNDARY"] = True
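
The stride arithmetic above is the inductor-level analogue of a zero-copy regrouping of the query-head dimension. A hedged eager-mode illustration with plain `torch.Tensor.as_strided` (sizes are made up):

```python
import torch

B, Hq, Hkv, Mq, D = 2, 8, 2, 16, 64
G = Hq // Hkv  # 4 query heads share each kv head

q = torch.randn(B, Hq, Mq, D)
sz, st = list(q.size()), list(q.stride())

# [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] without copying: the Hq axis is split
# into (Hkv, G), so the new Hkv stride is G times the old per-head stride.
q_gqa = q.as_strided(
    (sz[0], Hkv, G, sz[2], sz[3]),
    (st[0], st[1] * G, st[1], st[2], st[3]),
)
assert torch.equal(q_gqa.reshape(B, Hq, Mq, D), q)  # same data, regrouped view
```
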
@@ -472,18 +497,18 @@ def create_flex_decoding_kernel(*args, **kwargs):
     # Reduction

-    g_M = lowerings[aten.max](buf_M, dim=-2, keepdim=True)[0]
+    g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
     adj_M = lowerings[aten.sub](buf_M, g_M)
     alpha = lowerings[aten.exp2](adj_M)

     buf_L = lowerings[aten.mul](buf_L, alpha)
-    g_L = lowerings[aten.sum](buf_L, axis=-2)
+    g_L = lowerings[aten.sum](buf_L, axis=1)
     logsumexp = lowerings[aten.log2](g_L)
-    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=-2))
+    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))

     alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
     buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
-    output = lowerings[aten.sum](buf_ACC, axis=-3)
+    output = lowerings[aten.sum](buf_ACC, axis=1)
     L_unseq = lowerings[aten.unsqueeze](g_L, 3)
     output = lowerings[aten.div](output, L_unseq)
     output = lowerings[prims.convert_element_type](output, query.get_dtype())
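
The reduction above stitches the per-split partial results together with a base-2 logsumexp. A hedged eager-mode sketch of the same arithmetic on plain tensors (shapes follow the new `[B, SPLIT_KV, Hq, M]` / `[B, SPLIT_KV, Hq, M, D]` layout; the inputs are random stand-ins for the kernel's partial outputs):

```python
import torch

B, T, Hq, M, D = 2, 4, 8, 16, 64            # T = number of KV splits
buf_M = torch.randn(B, T, Hq, M)            # per-split running row max (base-2 domain)
buf_L = torch.rand(B, T, Hq, M) + 0.5       # per-split sum of exp2(score - row max)
buf_ACC = torch.randn(B, T, Hq, M, D)       # per-split un-normalized outputs

g_M = buf_M.max(dim=1, keepdim=True).values     # global row max across splits
alpha = torch.exp2(buf_M - g_M)                 # rescale factor for each split

g_L = (buf_L * alpha).sum(dim=1)                # combined softmax denominator
logsumexp = torch.log2(g_L) + g_M.squeeze(1)    # logsumexp in base 2

output = (buf_ACC * alpha.unsqueeze(-1)).sum(dim=1) / g_L.unsqueeze(-1)
print(output.shape, logsumexp.shape)            # (B, Hq, M, D) and (B, Hq, M)
```
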
@@ -79,9 +79,10 @@ def _vmap_for_bhqkv(
     prefix: Tuple[Optional[int], ...],
     suffix: Tuple[Optional[int], ...] = (),
     out_dims: Union[int, List[Optional[int]]] = 0,
+    group_dim: bool = False,
 ):
-    """Used to vmap both score_mods and mask_mods over 4-dimensional inputs.
-    Mapping over the [b, h, q_idx, kv_idx] dimensions.
+    """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs.
+    Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions.

     Args:
         fn (callable): The function to vmap.
@@ -98,10 +99,19 @@ def _vmap_for_bhqkv(
         callable: The vmapped function.
     """
     # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions
+    dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = []
     dimensions = [
         (None, None, None, 0),
         (None, None, 0, None),
         (None, 0, None, None),
+    ]
+
+    if group_dim:
+        dimensions += [
+            (None, 0, None, None),
+        ]
+
+    dimensions += [
         (0, None, None, None),
     ]
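
For intuition, here is a hedged standalone sketch of what the nested vmap produces in the plain 4-D (non-GQA) case: one `torch.func.vmap` per entry of `dimensions`, applied innermost-first over kv_idx, q_idx, head, and batch. The `causal_mask` below is only an example mask_mod, not part of the library.

```python
import torch
from torch.func import vmap

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# One vmap per dimensions entry: kv_idx innermost, then q_idx, head, batch.
dimensions = [
    (None, None, None, 0),
    (None, None, 0, None),
    (None, 0, None, None),
    (0, None, None, None),
]
fn = causal_mask
for in_dims in dimensions:
    fn = vmap(fn, in_dims=in_dims)

B, H, Q, KV = 1, 2, 4, 4
mask = fn(torch.arange(B), torch.arange(H), torch.arange(Q), torch.arange(KV))
print(mask.shape)  # torch.Size([1, 2, 4, 4]); lower-triangular over the last two dims
```
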
@@ -616,7 +626,7 @@ def create_mask(
     Args:
         mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores.
         B (int): Batch size.
-        H (int): Number of heads.
+        H (int): Number of query heads.
         Q_LEN (int): Sequence length of query.
         KV_LEN (int): Sequence length of key/value.
         device (str): Device to run the mask creation on.
@@ -698,7 +708,7 @@ def create_block_mask(
             It should return a boolean tensor indicating which attention connections are allowed (True)
             or masked out (False).
         B (int): Batch size.
-        H (int): Number of heads.
+        H (int): Number of query heads.
         Q_LEN (int): Sequence length of query.
         KV_LEN (int): Sequence length of key/value.
         device (str): Device to run the mask creation on.
@@ -794,6 +804,7 @@ def flex_attention(
     score_mod: Optional[_score_mod_signature] = None,
     block_mask: Optional[BlockMask] = None,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
     kernel_options: Optional[Dict[str, Any]] = None,
 ) -> Tensor:
     r"""This function implements scaled dot product attention with an arbitrary attention score modification function.
@@ -818,20 +829,21 @@ def flex_attention(
         - ``score``: A scalar tensor representing the attention score,
           with the same data type and device as the query, key, and value tensors.
         - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating
-          the batch index, head index, query index, and key/value index, respectively.
+          the batch index, query head index, query index, and key/value index, respectively.
           These should have the ``torch.int`` data type and be located on the same device as the score tensor.

     Args:
-        query (Tensor): Query tensor; shape :math:`(B, H, L, E)`.
-        key (Tensor): Key tensor; shape :math:`(B, H, S, E)`.
-        value (Tensor): Value tensor; shape :math:`(B, H, S, Ev)`.
+        query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`.
+        key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`.
+        value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`.
         score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied.
         block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention.
         scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
+        enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads.
         kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels.

     Returns:
-        output (Tensor): Attention output; shape :math:`(B, H, L, Ev)`.
+        output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`.

     Shape legend:
         - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
@@ -855,6 +867,20 @@ def flex_attention(
         raise NotImplementedError("NYI: S must be <128 or a multiple of 128")
     if key.size(-2) % 128 != 0:
         raise NotImplementedError("NYI: L must be a multiple of 128")
+    if (not enable_gqa) and query.size(-3) != key.size(-3):
+        raise ValueError(
+            f"Expect query and key/value to have the same number of heads "
+            f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. "
+            f"Try setting enable_gqa=True for GQA."
+        )
+    if enable_gqa:
+        Hq = query.size(1)
+        Hkv = key.size(1)
+        if Hq % Hkv != 0:
+            raise ValueError(
+                f"Expect number of query heads to be a multiple of kv heads for GQA "
+                f"but got Hq={Hq} and Hkv={Hkv}."
+            )

     if score_mod is None:
         score_mod = _identity
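
The head-count rules added above boil down to a simple shape check; the helper below is a hedged standalone mirror of that logic, not the library call itself:

```python
def check_heads(Hq: int, Hkv: int, enable_gqa: bool) -> None:
    # Mirrors the validation in flex_attention: equal head counts unless GQA is
    # enabled, in which case Hq must be an integer multiple of Hkv.
    if not enable_gqa and Hq != Hkv:
        raise ValueError(f"Hq={Hq} and Hkv={Hkv} must match; set enable_gqa=True for GQA")
    if enable_gqa and Hq % Hkv != 0:
        raise ValueError(f"Hq={Hq} must be a multiple of Hkv={Hkv} for GQA")

check_heads(8, 2, enable_gqa=True)   # ok: 4 query heads per kv head
check_heads(8, 8, enable_gqa=False)  # ok: plain multi-head attention
# check_heads(8, 3, enable_gqa=True) would raise: 8 is not a multiple of 3
```
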
@@ -871,8 +897,9 @@ def flex_attention(
     )

     if torch.compiler.is_dynamo_compiling():
-        # mark head_dim always to be static
+        # mark head_dim and number of heads to be static
         for x in [query, key, value]:
+            torch._dynamo.mark_static(x, -3)
             torch._dynamo.mark_static(x, -1)
         out, _ = flex_attention_hop(
             query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options