Add explicit GQA support. (#131559)

### tl;dr
This PR adds GQA support to the higher-order op `flex_attention`.

## Details
When `enable_gqa` is set to True, the HOP `flex_attention(query, key, value, score_mod, block_mask, enable_gqa=True)` runs Grouped Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of `Hq//Hkv` query heads attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi-Head Attention (MHA), where the number of query heads is equal to the number of kv heads.
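To make the grouping concrete, here is a minimal illustrative sketch (not part of this PR) of how query heads map onto kv heads:

```python
# Illustrative only: how query heads map to kv heads when enable_gqa=True.
# With Hq=8 query heads and Hkv=2 kv heads, every 4 query heads share one kv head.
Hq, Hkv = 8, 2
group = Hq // Hkv  # number of query heads per kv head
for hq in range(Hq):
    hkv = hq // group
    print(f"query head {hq} -> kv head {hkv}")
# query heads 0-3 -> kv head 0; query heads 4-7 -> kv head 1
```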

The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor

def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
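Because the head index handed to these callables is the query head index, a mod can derive the shared kv head (or the position within the group) on its own. A hypothetical sketch of a GQA-aware `mask_mod` (the windowing scheme and names are illustrative, not from this PR):

```python
import torch

def make_windowed_causal(num_q_heads: int, num_kv_heads: int, base_window: int):
    """Causal mask with a sliding window that depends on the kv-head group (illustrative)."""
    group = num_q_heads // num_kv_heads

    def mask_mod(
        batch: torch.Tensor,
        q_head: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        hkv = q_head // group             # kv head shared by this query head
        window = base_window * (hkv + 1)  # e.g. a wider window for later kv heads
        return (token_q >= token_kv) & (token_q - token_kv <= window)

    return mask_mod
```

Such a `mask_mod` can then be passed to `create_block_mask`, whose `H` argument is the number of query heads.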

## Example
```python
from typing import Optional

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

torch.manual_seed(0)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors.
# The input tensors have shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here, groups of 4 query heads (8 query heads total) share each of the 2 KV heads.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)

query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take an ALiBi bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Let's calculate the kv head index from the query head index
        group = num_q_heads // num_kv_heads
        hkv = hq // group

        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias

# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the compiled flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)

```
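
The eager reference path in this PR handles GQA by broadcasting the key/value heads with `torch.repeat_interleave`. As a sanity check (a sketch reusing the tensors from the example above; not part of the PR), the unmasked GQA output can be compared against SDPA on broadcast key/value tensors:

```python
import torch.nn.functional as F

# Broadcast kv heads up to the query heads, mirroring the eager reference.
G = query.size(1) // key.size(1)                  # 8 // 2 = 4 query heads per kv head
key_b = torch.repeat_interleave(key, G, dim=1)    # [2, 8, 2048, 64]
value_b = torch.repeat_interleave(value, G, dim=1)

ref = F.scaled_dot_product_attention(query, key_b, value_b)
gqa_out = flex_attention(query, key, value, enable_gqa=True)
torch.testing.assert_close(ref, gqa_out, atol=2e-2, rtol=2e-2)
```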

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
joydddd 2024-08-09 11:09:18 -07:00 committed by PyTorch MergeBot
parent dc8bb2636c
commit 4110cb6ba7
7 changed files with 461 additions and 362 deletions


@ -15,6 +15,7 @@ import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
create_block_mask,
create_mask,
flex_attention,
)
@ -106,9 +107,7 @@ def generate_inputs(
torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
query = (
make_q()
.view(batch_size, num_h_groups * q_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
make_q().view(batch_size, q_sequence_length, q_heads, head_dim).transpose(1, 2)
)
key = (
make_kv()
@ -146,8 +145,9 @@ def run_single_experiment(
if get_func_name(config.mask_mod) == "causal":
kwargs["is_causal"] = True
def eager_sdpa(query, key, value, _):
return F.scaled_dot_product_attention(query, key, value, **kwargs)
def eager_sdpa(query, key, value, attn_mask):
out = F.scaled_dot_product_attention(query, key, value, attn_mask, **kwargs)
return out.reshape(batch_size, q_heads, q_seq_len, head_dim)
if max_autotune:
compiled_sdpa = torch.compile(
@ -161,27 +161,62 @@ def run_single_experiment(
if mask_mod:
block_mask = create_block_mask(
mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
mask_mod, 1, 1, q_seq_len, kv_seq_len, query.device
)
else:
block_mask = _create_empty_block_mask(query, key)
if mask_mod and get_func_name(mask_mod) != "causal":
attn_mask = create_mask(mask_mod, 1, 1, query.shape[-2], key.shape[-2])
else:
attn_mask = None
# Broadcast query/key for eager.
b_key = torch.repeat_interleave(key, q_heads // kv_heads, dim=1)
b_value = torch.repeat_interleave(value, q_heads // kv_heads, dim=1)
forward_eager_time = benchmark_torch_function_in_microseconds(
eager_sdpa, query, key, value, score_mod
eager_sdpa, query, b_key, b_value, attn_mask
)
forward_compiled_time = benchmark_torch_function_in_microseconds(
compiled_sdpa, query, key, value, score_mod, block_mask
compiled_sdpa,
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
)
out_eager = eager_sdpa(query, b_key, b_value, attn_mask)
out_compile = compiled_sdpa(
query,
b_key,
b_value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
)
if score_mod is None:
torch.testing.assert_close(out_eager, out_compile, atol=1e-2, rtol=1e-2)
if config.calculate_bwd_time:
out_eager = eager_sdpa(query, key, value, score_mod)
out_eager = eager_sdpa(query, b_key, b_value, attn_mask)
dOut = torch.randn_like(out_eager)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, dOut, retain_graph=True
)
out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
dOut = torch.randn_like(out_eager)
out_compile = compiled_sdpa(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
)
dOut = torch.randn_like(out_compile)
backward_compile_time = benchmark_torch_function_in_microseconds(
out_compile.backward, dOut, retain_graph=True
)
@ -250,8 +285,6 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
def get_func_name(func):
if func is None:
return "None"
if "gqa" in func.__name__:
return func.__name__
func_str = str(func)
if "<locals>" in func_str:
# For locally defined functions
@ -369,8 +402,9 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
return score + 2 * h
function_dict = {
"noop": noop,
"noop": None,
"causal": None,
"offset": None,
"rel": relative_bias,
"head_bias": head_bias,
}
@ -384,45 +418,22 @@ def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
def causal(b, h, m, n):
return m >= n
def gen_offset(off):
def offset(b, h, m, n):
return m + off >= n
return offset
mask_mod_dict = {
"noop": None,
"causal": causal,
"offset": gen_offset,
"rel": None,
"head_bias": None,
}
return [mask_mod_dict[name] for name in score_mods]
def get_gqa_score_mod(score_mod, G, q_seq_len):
if score_mod is None:
return None
def score_mod_gqa(score, b, hkv, m, n):
g = m // q_seq_len
new_m = m % q_seq_len
hq = hkv * G + g
return score_mod(score, b, hq, new_m, n)
score_mod_name = get_func_name(score_mod)
set_func_name(score_mod_gqa, score_mod_name + "_gqa")
return score_mod_gqa
def get_gqa_mask_mod(mask_mod, G, q_seq_len):
if mask_mod is None:
return None
def mask_mod_gqa(b, h, m, n):
g = m // q_seq_len
new_m = m % q_seq_len
hq = h * G + g
return mask_mod(b, hq, new_m, n)
mask_mod_name = get_func_name(mask_mod)
set_func_name(mask_mod_gqa, mask_mod_name + "_gqa")
return mask_mod_gqa
def generate_flash_configs(
calculate_bwd: bool,
dtype: torch.dtype,
@ -521,16 +532,14 @@ def generate_experiment_configs(
(q_heads, kv_heads),
(q_seq_len, kv_seq_len),
head_dim,
score_mod,
mask_mod,
(score_mod, mask_mod),
dtype,
) in itertools.product(
kv_cache_size if kv_cache_size else batch_sizes,
num_heads,
q_kv_seq_lens,
head_dims,
score_mods,
mask_mods,
zip(score_mods, mask_mods),
dtypes,
):
if kv_cache_size:
@ -541,11 +550,10 @@ def generate_experiment_configs(
if bsz <= 0:
continue
if q_heads != kv_heads: # GQA work around before it's explicitly supported
assert q_heads % kv_heads == 0
G = q_heads // kv_heads
score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)
assert q_heads % kv_heads == 0
if mask_mod and get_func_name(mask_mod) == "gen_offset":
mask_mod = mask_mod(kv_seq_len // 2)
all_configs.append(
ExperimentConfig(
@ -577,7 +585,7 @@ def main(args):
args.mods,
args.decoding,
args.kv_cache_size,
args.cal_bandwidth,
args.throughput,
)
):
results.append(
@ -658,7 +666,7 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
""",
)
parser.add_argument(
"--cal-bandwidth",
"--throughput",
action="store_true",
help="Calculate kernel memory bandwidth & computational throughput. ",
)


@ -53,13 +53,23 @@ def rmse(ref, res):
return torch.sqrt(torch.mean(torch.square(ref - res)))
def create_attention(score_mod, block_mask):
return functools.partial(flex_attention, score_mod=score_mod, block_mask=block_mask)
def create_attention(score_mod, block_mask, enable_gqa=False):
return functools.partial(
flex_attention,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
)
def create_block_mask_test(score_mod, query, key):
block_mask = create_block_mask(
score_mod, 1, 1, query.shape[-2], key.shape[-2], query.device
score_mod,
1,
1,
query.shape[-2],
key.shape[-2],
query.device,
)
return block_mask
@ -276,7 +286,9 @@ class TestFlexAttention(InductorTestCase):
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
block_mask = None
sdpa_partial = create_attention(score_mod, block_mask)
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial)
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
ref_out = sdpa_partial(q_ref, k_ref, v_ref)
@ -548,6 +560,23 @@ class TestFlexAttention(InductorTestCase):
D,
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mod", test_score_mods)
def test_GQA(self, dtype: torch.dtype, score_mod: Callable):
self.run_test(
score_mod,
dtype,
B,
H * 4, # Hq = 4*Hkv.
S // 8,
D,
B,
H,
S,
D,
)
test_strides = [
((H * S * D, S * D, D, 1), 997), # offset
((H * D, D, B * H * D, 1), 499), # transposed dimensions
@ -1166,6 +1195,29 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.run_test_with_call(attention)
@supported_platform
def test_GQA_causal_mask(self):
def mask_mod(b, h, q, kv):
return q >= kv
block_mask = create_block_mask(mask_mod, 1, 1, S // 8, S // 8)
attention = functools.partial(
flex_attention, block_mask=block_mask, enable_gqa=True
)
self.run_test_with_call(
attention,
torch.float16,
B,
H * 4, # Hq = 4*Hkv.
S // 8,
D,
B,
H,
S // 8,
D,
)
@supported_platform
def test_custom_block_mask_generator(self):
def mask_mod(b, h, q, kv):


@ -38,30 +38,13 @@ index = torch.ops.aten.index
Tensor = torch.Tensor
# score_mod / gqa_mask convert for GQA inputs before GQA is explictly supported
def get_gqa_score_mod(score_mod, G, q_seq_len):
def score_mod_gqa(score, b, hkv, m, n):
g = m // q_seq_len
g = torch.where(g < G, g, 0)
new_m = m % q_seq_len
hq = hkv * G + g
return score_mod(score, b, hq, new_m, n)
return score_mod_gqa
def get_gqa_mask_mod(mask_fn, G, q_seq_len):
def mask_mod_gqa(b, hkv, m, n):
g = m // q_seq_len
new_m = m % q_seq_len
hq = hkv * G + g
return mask_fn(b, hq, new_m, n)
return mask_mod_gqa
def create_attention(score_mod, block_mask):
return functools.partial(flex_attention, score_mod=score_mod, block_mask=block_mask)
def create_attention(score_mod, block_mask, enable_gqa=False):
return functools.partial(
flex_attention,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
)
def create_block_mask_test(score_mod, query, key):
@ -200,7 +183,6 @@ test_Hq_Hkv = [
(16, 1),
(8, 2),
(16, 16),
(20, 1),
]
(Hq, Hkv) = (16, 8)
@ -251,20 +233,11 @@ class TestFlexDecoding(InductorTestCase):
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)
def _check_out_and_grad(
def _check_out(
self,
golden_out: torch.Tensor,
ref_out: torch.Tensor,
compiled_out: torch.Tensor,
q_gold: torch.Tensor,
q_ref: torch.Tensor,
q: torch.Tensor,
k_gold: torch.Tensor,
k_ref: torch.Tensor,
k: torch.Tensor,
v_gold: torch.Tensor,
v_ref: torch.Tensor,
v: torch.Tensor,
):
dtype = ref_out.dtype
with torch.no_grad():
@ -278,21 +251,6 @@ class TestFlexDecoding(InductorTestCase):
# Checkout output
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
# TODO: add backward support
# # Check gradients
# q_fudge_factor = 2.5 * fudge_factor
# self._check_equal(
# q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
# )
# k_fudge_factor = 4 * fudge_factor
# self._check_equal(
# k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
# )
# v_fudge_factor = 4 * fudge_factor
# self._check_equal(
# v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
# )
def run_test(
self,
score_mod: Callable,
@ -307,10 +265,8 @@ class TestFlexDecoding(InductorTestCase):
KV_D: int = D,
):
assert Q_H % KV_H == 0
score_mod = get_gqa_score_mod(score_mod, G=Q_H // KV_H, q_seq_len=Q_S)
q = torch.randn(
(Q_B, KV_H, Q_S * (Q_H // KV_H), Q_D),
(Q_B, Q_H, Q_S, Q_D),
dtype=dtype,
device="cuda",
requires_grad=False,
@ -325,34 +281,18 @@ class TestFlexDecoding(InductorTestCase):
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
block_mask = None
sdpa_partial = create_attention(score_mod, block_mask)
sdpa_partial = create_attention(
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
)
compiled_sdpa = torch.compile(sdpa_partial)
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
ref_out = sdpa_partial(q_ref, k_ref, v_ref)
compiled_out = compiled_sdpa(q, k, v)
# TODO: Add backward support
# backward_grad = torch.randn(
# (Q_B, KV_H, Q_S *(Q_H // KV_H), Q_D), dtype=dtype, device="cuda"
# )
# golden_out.backward(backward_grad.to(torch.float64))
# ref_out.backward(backward_grad)
# compiled_out.backward(backward_grad)
self._check_out_and_grad(
self._check_out(
golden_out,
ref_out,
compiled_out,
q_gold,
q_ref,
q,
k_gold,
k_ref,
k,
v_gold,
v_ref,
v,
)
def run_test_with_call(
@ -391,19 +331,10 @@ class TestFlexDecoding(InductorTestCase):
ref_out = golden_call(q_ref, k_ref, v_ref)
compiled_out = compiled_sdpa(q, k, v)
self._check_out_and_grad(
self._check_out(
golden_out,
ref_out,
compiled_out,
q_gold,
q_ref,
q,
k_gold,
k_ref,
k,
v_gold,
v_ref,
v,
)
@supported_platform
@ -488,7 +419,7 @@ class TestFlexDecoding(InductorTestCase):
k_shape = (B, Hkv, S, D)
v_shape = (B, Hkv, S, D)
q = q1.view(Hq // Hkv, Hkv, B, D).transpose(0, 2)
q = q1.view(1, Hq, B, D).transpose(0, 2)
k_strides, k_offset = k_s(B, Hkv, S, D)
k_max = [x * (y - 1) for x, y in zip(k_strides, k_shape)]
@ -503,7 +434,9 @@ class TestFlexDecoding(InductorTestCase):
v = torch.as_strided(v1, v_shape, v_strides, v_offset)
sdpa_partial = create_attention(
score_mod=_generate_alibi_bias(8), block_mask=None
score_mod=_generate_alibi_bias(8),
block_mask=None,
enable_gqa=(not Hq == Hkv),
)
compiled_sdpa = torch.compile(sdpa_partial)
ref_out = sdpa_partial(q, k, v)


@ -184,6 +184,11 @@ def math_attention(
score_mod: The score_mod function
other_buffers: Other buffers that are passed to the score_mod function
"""
# broadcast query & key along head dim for GQA
G = query.size(1) // key.size(1)
value = torch.repeat_interleave(value, G, dim=1)
key = torch.repeat_interleave(key, G, dim=1)
_, post_mod_scores = _math_attention_inner(
query,
key,
@ -673,6 +678,10 @@ def sdpa_dense_backward(
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
G = query.size(1) // key.size(1)
key = torch.repeat_interleave(key, G, dim=1)
value = torch.repeat_interleave(value, G, dim=1)
scores, post_mod_scores = _math_attention_inner(
query,
key,
@ -731,6 +740,18 @@ def sdpa_dense_backward(
grad_query = grad_scores @ key
grad_key = grad_scores.transpose(-2, -1) @ query
# Reduce DK, DV along broadcasted heads.
grad_key = grad_key.view(
grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
)
grad_value = grad_value.view(
grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
)
grad_key = torch.sum(grad_key, 2, keepdim=False)
grad_value = torch.sum(grad_value, 2, keepdim=False)
return grad_query.contiguous(), grad_key.contiguous(), grad_value.contiguous()


@ -29,7 +29,7 @@ log = logging.getLogger(__name__)
aten = torch.ops.aten
def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta):
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
Each block is responsible for iterating over blocks of keys and values calculating
@ -37,7 +37,7 @@ def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta):
"""
import triton
return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1)
return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
def create_placeholder(
@ -134,6 +134,7 @@ compute_flex_attention = r"""
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
@ -163,29 +164,31 @@ compute_flex_attention = r"""
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
Z = {{size("Q", 0)}}
H = {{size("Q", 1)}}
HQ = {{size("Q", 1)}}
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_z = tl.program_id(1) // H
off_h = tl.program_id(1) % H
off_z = tl.program_id(1) // HQ
off_hq = tl.program_id(1) % HQ
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_z * stride_qz + off_h * stride_qh
k_offset = off_z * stride_kz + off_h * stride_kh
v_offset = off_z * stride_vz + off_h * stride_vh
q_offset = off_z * stride_qz + off_hq * stride_qh
k_offset = off_z * stride_kz + off_hkv * stride_kh
v_offset = off_z * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_H = {{size("KV_NUM_BLKS", 1)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SPARSE_Z
sparse_idx_h = off_h % SPARSE_H
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
@ -201,7 +204,7 @@ compute_flex_attention = r"""
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
@ -244,7 +247,7 @@ compute_flex_attention = r"""
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
acc, l_i, m_i,
off_z, off_h, offs_m, offs_n,
off_z, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, kv_num_blocks * SPARSE_KV_MULTIPLE,
MATMUL_PRECISION,
@ -282,7 +285,7 @@ compute_flex_attention = r"""
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
acc, l_i, m_i,
off_z, off_h, offs_m, offs_n,
off_z, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, kv_num_blocks * SPARSE_KV_MULTIPLE,
MATMUL_PRECISION,
@ -294,16 +297,14 @@ compute_flex_attention = r"""
# Store output and logsumexp
l_i = tl.where(l_i == 0, 1, l_i)
acc = acc / l_i[:, None]
idx_z = tl.program_id(1) // H
idx_h = tl.program_id(1) % H
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
idx_z = tl.program_id(1) // HQ
idx_hq = tl.program_id(1) % HQ
idx_m = offs_m[:, None]
idx_d = tl.arange(0, BLOCK_DMODEL)[None, :]
mask = idx_m < Q_LEN
# TODO generalize and add proper mask support
{{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}
{{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
# TODO dont want to write this if we dont require grad
if OUTPUT_LOGSUMEXP:
@ -320,7 +321,8 @@ def forward_inner(
q, K_block_ptr, V_block_ptr,
# accumulated values
acc, l_i, m_i,
# Offsets
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# blocksparse data
kv_indices, kv_num_blocks,
@ -351,8 +353,6 @@ def forward_inner(
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = offs_m[:, None]
n = offs_n[None, :]
# TODO: Add load mask in modification when M/N Boundary is not safe
{{ modification(
subgraph_number=0,
@ -360,8 +360,8 @@ def forward_inner(
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
m="offs_m",
n="offs_n",
out="qk"
) | indent_except_first(2) }}
@ -372,8 +372,8 @@ def forward_inner(
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
m="offs_m",
n="offs_n",
) | indent_except_first(3) }}
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
@ -627,7 +627,7 @@ def flex_attention(
query.get_stride(),
)
# see NOTE:[TritonTemplates with multiple outputs]
logsumexp_shape = query.get_size()[:-1] # [B, H, M]
logsumexp_shape = query.get_size()[:-1] # [B, Hq, Mq]
logsumexp = empty_strided(
logsumexp_shape,
None,
@ -636,6 +636,11 @@ def flex_attention(
)
kernel_options = dict(kernel_options)
kernel_options["SM_SCALE"] = scale
# Determine GQA broadcast factor.
gqa_shared_heads = query.get_size()[1] // key.get_size()[1]
kernel_options["GQA_SHARED_HEADS"] = gqa_shared_heads
# Inside of Triton kernel, only apply partial masking if partial blocks are computed.
# full_kv_num_blocks is None if partial blocks are not computed
kernel_options["HAS_FULL_BLOCKS"] = full_kv_num_blocks is not None
@ -736,20 +741,20 @@ def flex_attention(
def flex_attention_backward_grid(
batch_size, num_heads, num_queries, d_model, num_key_value, meta
batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
):
"""How is this kernel parallelized?
Currently this is only parallelizing over batch * num_heads, but we can, and want to
parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require
atomic updates to some grad values or to have a two pass kernel design.
Currently this is only parallelizing over batch* kv_heads, but we can, and want to
parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
To do this will either require atomic updates to some grad values or to have a two pass kernel design.
"""
import triton
return (
triton.cdiv(num_queries, meta["BLOCK_M2"])
triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
+ triton.cdiv(num_key_value, meta["BLOCK_N1"]),
1,
batch_size * num_heads,
batch_size * kv_heads,
)
@ -768,6 +773,7 @@ flex_attention_backward_template = TritonTemplate(
# inductor codegen
# M: Number of queries, N: Number of keys/values, D: Model dimension
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
@ -799,7 +805,8 @@ flex_attention_backward_template = TritonTemplate(
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
Z = {{size("Q", 0)}}
H = {{size("Q", 1)}}
HQ = {{size("Q", 1)}}
HKV = {{size("K", 1)}}
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
@ -807,39 +814,25 @@ flex_attention_backward_template = TritonTemplate(
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_hz = tl.program_id(2)
off_z = off_hz // H # batch idx
off_h = off_hz % H # head idx
off_z = off_hz // HKV # batch idx
off_hkv = off_hz % HKV # kv head idx
SM_Z = {{size("KV_NUM_BLKS", 0)}}
SM_H = {{size("KV_NUM_BLKS", 1)}}
SM_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SM_Z
sparse_idx_h = off_h % SM_H
sparse_hz_offset = sparse_idx_z * SM_H + sparse_idx_h
k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
off_chz = (off_hz * Q_LEN).to(tl.int64)
q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64)
k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64)
v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64)
do_adj = (stride_doh * (off_hz % H) + stride_doz * (off_hz // H)).to(tl.int64)
dq_adj = (stride_dqh * (off_hz % H) + stride_dqz * (off_hz // H)).to(tl.int64)
dv_adj = (stride_dvh * (off_hz % H) + stride_dvz * (off_hz // H)).to(tl.int64)
# offset pointers for batch/head
Q += q_adj
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DO += do_adj
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ += dq_adj
DV += dv_adj
LSE += off_chz
DELTA += off_chz
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, BLOCK_DMODEL)
@ -849,43 +842,60 @@ flex_attention_backward_template = TritonTemplate(
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_pid_mask = off_pid // SPARSE_Q_MULTIPLE
KV_IDX_N = {{size("KV_IDX", 3)}}
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
sparse_idx_hq2 = off_hq2 % SM_HQ
sparse_hz_offset = sparse_idx_z * SM_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
start_m2 = off_pid * BLOCK_M2
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
do = tl.load(DO + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod)
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
Di = tl.load(DELTA + offs_m2)
lse = tl.load(LSE + offs_m2)
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(KV_IDX + sparse_kv_idx_offset) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_n2,
K, V,
dq, q, do, Di, lse,
off_z, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
@ -900,11 +910,11 @@ flex_attention_backward_template = TritonTemplate(
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_n2,
K, V,
dq, q, do, Di, lse,
off_z, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
@ -913,7 +923,7 @@ flex_attention_backward_template = TritonTemplate(
)
# Write back dQ.
dq_ptrs = DQ + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
tl.store(dq_ptrs, dq)
else:
@ -923,14 +933,10 @@ flex_attention_backward_template = TritonTemplate(
pid_mask = pid // SPARSE_KV_MULTIPLE
Q_IDX_M = {{size("Q_IDX", 3)}}
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
stride_q_idx_h = {{stride("Q_IDX", 1)}}
stride_q_idx_n = {{stride("Q_IDX", 2)}}
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask * stride_q_idx_n # noqa: B950
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
@ -943,46 +949,66 @@ flex_attention_backward_template = TritonTemplate(
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd)
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
dk, dv = bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=False
)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SM_HQ
sparse_hz_offset = sparse_idx_z * SM_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
dk, dv = bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_m1,
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_z, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True
IS_FULL_BLOCKS=False
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_z, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd
@ -993,12 +1019,13 @@ flex_attention_backward_template = TritonTemplate(
dk *= SM_SCALE
mask = index_n < KV_LEN
{{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
{{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
@triton.jit
def bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_n2,
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
@ -1032,7 +1059,7 @@ def bwd_dq_inner(
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
out="qk"
@ -1044,7 +1071,7 @@ def bwd_dq_inner(
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
) | indent_except_first(3) }}
@ -1065,7 +1092,7 @@ def bwd_dq_inner(
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
@ -1093,8 +1120,9 @@ def bwd_dq_inner(
@triton.jit
def bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_m1,
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
@ -1130,7 +1158,7 @@ def bwd_dkdv_inner(
output_name="post_mod_scores",
score="qkT",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
out="qkT"
@ -1141,7 +1169,7 @@ def bwd_dkdv_inner(
output_name="mask_mod_output",
score="qkT",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
) | indent_except_first(3) }}
@ -1167,7 +1195,7 @@ def bwd_dkdv_inner(
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_h",
h="off_hq",
m="m",
n="n",
grad_score_mod="dsT"
@ -1305,6 +1333,11 @@ def flex_attention_backward(*args, **kwargs):
kernel_options = dict(kernel_options)
kernel_options["SM_SCALE"] = scale
# Determine GQA factor
gqa_shared_heads = query.get_size()[1] // key.get_size()[1]
kernel_options["GQA_SHARED_HEADS"] = gqa_shared_heads
# Inside of Triton kernel, only apply partial masking if partial blocks are computed.
# full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed.
kernel_options["HAS_FULL_BLOCKS"] = full_kv_num_blocks is not None
@ -1370,7 +1403,7 @@ def flex_attention_backward(*args, **kwargs):
layout=layout_k, # We use store_output only for grad_key
subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
mutated_inputs=[grad_query, grad_value],
call_sizes=query.get_size() + [key.get_size()[2]],
call_sizes=query.get_size() + key.get_size()[1:3],
num_stages=num_stages,
num_warps=num_warps,
**kernel_options,


@ -9,7 +9,7 @@ from torch._inductor.virtualized import V
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .flex_attention import compute_forward_block, compute_next_offset_func
@ -18,15 +18,15 @@ aten = torch.ops.aten
prims = torch.ops.prims
def flex_decoding_grid(batch_size, num_heads, n_keys, d_model, meta):
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, SPLIT_KV, 1)
We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)
Each block is responsible for iterating over blocks of keys and values calculating
the local output for their tile of keys and values over all full length of query.
groups of SPLIT_KV blocks then combine their output to produce the final result.
"""
return (batch_size * num_heads, meta["SPLIT_KV"], 1)
return (batch_size * kv_heads, meta["SPLIT_KV"], 1)
flex_decoding_template = TritonTemplate(
@ -45,6 +45,7 @@ flex_decoding_template = TritonTemplate(
# TILE_KV: length of each local KV split
# BLOCK_M: block size that Q is padded along seqlen dim.
# BLOCK_N: block size of K & V along N dimension.
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
@ -61,70 +62,73 @@ flex_decoding_template = TritonTemplate(
#
# Output: ACC output accumulated across local KV split.
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define Q Strides
stride_qz = {{stride("Q", 0)}}
stride_qh = {{stride("Q", 1)}}
stride_qm = {{stride("Q", 2)}}
stride_qk = {{stride("Q", 3)}}
# Define K Strides
stride_kz = {{stride("K", 0)}}
stride_kh = {{stride("K", 1)}}
stride_kn = {{stride("K", 2)}}
stride_kk = {{stride("K", 3)}}
# Define V Strides
stride_vz = {{stride("V", 0)}}
stride_vh = {{stride("V", 1)}}
stride_vk = {{stride("V", 2)}}
stride_vn = {{stride("V", 3)}}
# Define M Strides
stride_mz = {{stride("M", 0)}}
stride_mh = {{stride("M", 1)}}
stride_mt = {{stride("M", 2)}}
stride_mm = {{stride("M", 3)}}
# Define L Strides
stride_lz = {{stride("L", 0)}}
stride_lh = {{stride("L", 1)}}
stride_lt = {{stride("L", 2)}}
stride_lm = {{stride("L", 3)}}
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
Z = {{size("Q", 0)}}
H = {{size("Q", 1)}}
Q_LEN = {{size("Q", 2)}}
HKV = {{size("Q", 1)}}
G: tl.constexpr = GQA_SHARED_HEADS
HQ = HKV * G
Q_LEN = {{size("Q", 3)}}
KV_LEN = {{size("K", 2)}}
# Make sure each split is a multiple of BLOCK_N
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_H = {{size("KV_NUM_BLKS", 1)}}
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
MATMUL_PRECISION = Q.dtype.element_ty
off_z = tl.program_id(0) // H
off_h = tl.program_id(0) % H
# Make sure each split is a multiple of BLOCK_N
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
off_z = tl.program_id(0) // HKV
off_hkv = tl.program_id(0) % HKV
off_t = tl.program_id(1)
sparse_idx_z = off_z % SPARSE_Z
sparse_idx_h = off_h % SPARSE_H
q_offset = off_z * stride_qz + off_hkv * stride_qh
k_offset = off_z * stride_kz + off_hkv * stride_kh
v_offset = off_z * stride_vz + off_hkv * stride_vh
# TODO: strided KV_IDX and KV_NUM_BLKS
sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SPARSE_Z
# TODO: support masks not broadcasted along the head dimension.
tl.device_assert(SPARSE_HQ == 1)
sparse_idx_h = 0
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# initialize offsets
tl.device_assert(BLOCK_M % G == 0)
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
off_g = tl.arange(0, G) # [G]
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_hq = offs_g + off_hkv * G
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_d = tl.arange(0, BLOCK_DMODEL)
# KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h
# Calculate KV blocks that belong this CTA.
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
q_offset = off_z * stride_qz + off_h * stride_qh
k_offset = off_z * stride_kz + off_h * stride_kh
v_offset = off_z * stride_vz + off_h * stride_vh
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(Q_LEN, BLOCK_DMODEL), # (M, d)
@ -134,18 +138,13 @@ flex_decoding_template = TritonTemplate(
order=(1, 0)
)
if SAFE_M_BOUNDARY:
q = tl.load(Q_block_ptr)
q = tl.load(Q + q_offset + q_range)
else:
q = tl.load(Q_block_ptr, boundary_check=(0, ))
mask = off_m[None, :, None] < Q_LEN
q = tl.load(Q + q_offset + q_range, mask)
# initialize offsets
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_DMODEL)
q = tl.reshape(q, [BLOCK_M, BLOCK_DMODEL])
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apply both score_mod and mask_mod
@ -171,7 +170,7 @@ flex_decoding_template = TritonTemplate(
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
@ -183,7 +182,7 @@ flex_decoding_template = TritonTemplate(
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, off_h, offs_m, offs_n,
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@ -216,7 +215,7 @@ flex_decoding_template = TritonTemplate(
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
@ -228,7 +227,7 @@ flex_decoding_template = TritonTemplate(
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, off_h, offs_m, offs_n,
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@ -237,42 +236,46 @@ flex_decoding_template = TritonTemplate(
IS_FULL_BLOCKS=True,
)
m_offset = off_h * stride_mh + off_z * stride_mz
l_offset = off_h * stride_lh + off_z * stride_lz
m_offset = off_t * stride_mt + off_z * stride_mz
l_offset = off_t * stride_lt + off_z * stride_lz
M_block_ptr = tl.make_block_ptr(
base=M + m_offset,
shape=(SPLIT_KV, Q_LEN), # (T, M)
strides=(stride_mt, stride_mm),
offsets=(off_t, 0),
block_shape=(1, BLOCK_M),
shape=(G, Q_LEN), # (G, M)
strides=(stride_mh, stride_mm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
L_block_ptr = tl.make_block_ptr(
base=L + l_offset,
shape=(SPLIT_KV, Q_LEN), # (T, M)
strides=(stride_lt, stride_lm),
offsets=(off_t, 0),
block_shape=(1, BLOCK_M),
shape=(G, Q_LEN), # (G, M)
strides=(stride_lh, stride_lm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
if SAFE_M_BOUNDARY:
tl.store(M_block_ptr, m_i[None, :])
tl.store(L_block_ptr, l_i[None, :])
tl.store(M_block_ptr, m_i)
tl.store(L_block_ptr, l_i)
else:
tl.store(M_block_ptr, m_i[None, :], boundary_check=(1,))
tl.store(L_block_ptr, l_i[None, :], boundary_check=(1,))
tl.store(M_block_ptr, m_i, boundary_check=(1,))
tl.store(L_block_ptr, l_i, boundary_check=(1,))
# -- store output
idx_z = off_z
idx_h = off_h
idx_t = off_t
idx_m = offs_m[:, None]
idx_d = offs_d[None, :]
idx_hq = off_hkv*G + off_g[:, None, None]
idx_m = off_m[None, :, None]
idx_d = offs_d[None, None, :]
# TODO generalize and add proper mask support
mask = (idx_m < Q_LEN)
{{store_output(("idx_z", "idx_h", "idx_t", "idx_m", "idx_d"), "acc", "mask")}}
acc = acc.reshape(G, BLOCK_M_PER_HQ, BLOCK_DMODEL)
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
"""
+ compute_forward_block
+ compute_next_offset_func,
@ -325,6 +328,15 @@ def create_flex_decoding_kernel(*args, **kwargs):
) = block_mask
kernel_options = dict(kernel_options)
# Calculate GQA head sharing
gqa_shared_heads = query.get_size()[1] // key.get_size()[1]
if not is_power_of_2(gqa_shared_heads):
raise ValueError(
"Number of shared query heads sharing the same KV head must be power of 2. "
)
kernel_options["GQA_SHARED_HEADS"] = gqa_shared_heads
# Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
kernel_options["HAS_FULL_BLOCKS"] = full_kv_num_blocks is not None
if (
@ -363,10 +375,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
assert kernel_options["SPLIT_KV"] <= MAX_SPLIT_KV
# create config dependent intermediate buffers
buf_ML_shape = query.get_size()[:-2] + [
MAX_SPLIT_KV,
query.get_size()[-2],
] # [B, H, SPLIT_KV, M]
buf_ACC_shape = (
query.get_size()[:1] + [MAX_SPLIT_KV] + query.get_size()[1:]
) # [B, SPLIT_KV, Hq, M, D]
buf_ML_shape = buf_ACC_shape[:-1]
buf_M = empty_strided(
buf_ML_shape,
None,
@ -380,10 +392,6 @@ def create_flex_decoding_kernel(*args, **kwargs):
device=query.get_device(),
)
buf_ACC_shape = (
query.get_size()[:-2] + [MAX_SPLIT_KV] + query.get_size()[-2:]
) # [B, H, SPLIT_KV, M, D]
layout_acc = FixedLayout(
query.get_device(),
torch.float32,
@ -403,14 +411,31 @@ def create_flex_decoding_kernel(*args, **kwargs):
V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
)
* gqa_shared_heads
),
16,
)
)
V.graph.sizevars.guard_leq(m, sympy.Integer(kernel_options["BLOCK_M"]))
# Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
gqa_query_shape = (
query.get_size()[:1]
+ [key.get_size()[1], gqa_shared_heads]
+ query.get_size()[2:]
)
gqa_query_stride = (
query.get_stride()[:1]
+ [query.get_stride()[1] * gqa_shared_heads, query.get_stride()[1]]
+ query.get_stride()[2:]
)
query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
V.graph.sizevars.guard_leq(
m * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
)
kernel_options["SAFE_M_BOUNDARY"] = (
query.get_size()[-2] % kernel_options["BLOCK_M"]
(m * gqa_shared_heads) % kernel_options["BLOCK_M"]
) == 0
kernel_options["SAFE_N_BOUNDARY"] = True
@ -472,18 +497,18 @@ def create_flex_decoding_kernel(*args, **kwargs):
# Reduction
g_M = lowerings[aten.max](buf_M, dim=-2, keepdim=True)[0]
g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
adj_M = lowerings[aten.sub](buf_M, g_M)
alpha = lowerings[aten.exp2](adj_M)
buf_L = lowerings[aten.mul](buf_L, alpha)
g_L = lowerings[aten.sum](buf_L, axis=-2)
g_L = lowerings[aten.sum](buf_L, axis=1)
logsumexp = lowerings[aten.log2](g_L)
logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=-2))
logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
output = lowerings[aten.sum](buf_ACC, axis=-3)
output = lowerings[aten.sum](buf_ACC, axis=1)
L_unseq = lowerings[aten.unsqueeze](g_L, 3)
output = lowerings[aten.div](output, L_unseq)
output = lowerings[prims.convert_element_type](output, query.get_dtype())


@ -79,9 +79,10 @@ def _vmap_for_bhqkv(
prefix: Tuple[Optional[int], ...],
suffix: Tuple[Optional[int], ...] = (),
out_dims: Union[int, List[Optional[int]]] = 0,
group_dim: bool = False,
):
"""Used to vmap both score_mods and mask_mods over 4-dimensional inputs.
Mapping over the [b, h, q_idx, kv_idx] dimensions.
"""Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs.
Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions.
Args:
fn (callable): The function to vmap.
@ -98,10 +99,19 @@ def _vmap_for_bhqkv(
callable: The vmapped function.
"""
# We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions
dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = []
dimensions = [
(None, None, None, 0),
(None, None, 0, None),
(None, 0, None, None),
]
if group_dim:
dimensions += [
(None, 0, None, None),
]
dimensions += [
(0, None, None, None),
]
@ -616,7 +626,7 @@ def create_mask(
Args:
mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores.
B (int): Batch size.
H (int): Number of heads.
H (int): Number of query heads.
Q_LEN (int): Sequence length of query.
KV_LEN (int): Sequence length of key/value.
device (str): Device to run the mask creation on.
@ -698,7 +708,7 @@ def create_block_mask(
It should return a boolean tensor indicating which attention connections are allowed (True)
or masked out (False).
B (int): Batch size.
H (int): Number of heads.
H (int): Number of query heads.
Q_LEN (int): Sequence length of query.
KV_LEN (int): Sequence length of key/value.
device (str): Device to run the mask creation on.
@ -794,6 +804,7 @@ def flex_attention(
score_mod: Optional[_score_mod_signature] = None,
block_mask: Optional[BlockMask] = None,
scale: Optional[float] = None,
enable_gqa: bool = False,
kernel_options: Optional[Dict[str, Any]] = None,
) -> Tensor:
r"""This function implements scaled dot product attention with an arbitrary attention score modification function.
@ -818,20 +829,21 @@ def flex_attention(
- ``score``: A scalar tensor representing the attention score,
with the same data type and device as the query, key, and value tensors.
- ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating
the batch index, head index, query index, and key/value index, respectively.
the batch index, query head index, query index, and key/value index, respectively.
These should have the ``torch.int`` data type and be located on the same device as the score tensor.
Args:
query (Tensor): Query tensor; shape :math:`(B, H, L, E)`.
key (Tensor): Key tensor; shape :math:`(B, H, S, E)`.
value (Tensor): Value tensor; shape :math:`(B, H, S, Ev)`.
query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`.
value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`.
score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied.
block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention.
scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads.
kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels.
Returns:
output (Tensor): Attention output; shape :math:`(B, H, L, Ev)`.
output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
@ -855,6 +867,20 @@ def flex_attention(
raise NotImplementedError("NYI: S must be <128 or a multiple of 128")
if key.size(-2) % 128 != 0:
raise NotImplementedError("NYI: L must be a multiple of 128")
if (not enable_gqa) and query.size(-3) != key.size(-3):
raise ValueError(
f"Expect query and key/value to have the same number of heads "
f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. "
f"Try setting enable_gqa=True for GQA."
)
if enable_gqa:
Hq = query.size(1)
Hkv = key.size(1)
if Hq % Hkv != 0:
raise ValueError(
f"Expect number of query heads to be a multiple of kv heads for GQA "
f"but got Hq={Hq} and Hkv={Hkv}."
)
if score_mod is None:
score_mod = _identity
@ -871,8 +897,9 @@ def flex_attention(
)
if torch.compiler.is_dynamo_compiling():
# mark head_dim always to be static
# mark head_dim and number of heads to be static
for x in [query, key, value]:
torch._dynamo.mark_static(x, -3)
torch._dynamo.mark_static(x, -1)
out, _ = flex_attention_hop(
query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options