pytorch/torch/_inductor/kernel/flex_decoding.py
joydddd 4110cb6ba7 Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to higher order op `flex_attention`.

## Details
When `enable_gqa` is set to True, HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Grouped Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of query heads (`Hq//Hkv` heads) attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi Head Attention (MHA), where the number of query heads is equal to the number of kv heads.

The `score_mod` and `mask_mod` APIs are adapted accordingly to take the query head index `q_head` as their head argument.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor

def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

torch.manual_seed(0)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len
# Here 8 query heads share 2 KV heads: each group of 4 query heads shares one KV head.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)

query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take alibi_bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Let's calculate kv head from query head index
        group = num_q_heads // num_kv_heads
        hkv = hq // group

        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias

# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask in eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg
2024-08-09 21:25:35 +00:00

519 lines
18 KiB
Python

# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any, List, Tuple
import sympy
import torch
from torch._inductor.virtualized import V
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .flex_attention import compute_forward_block, compute_next_offset_func
# Shorthands for the op namespaces used by the lowerings in this file.
aten = torch.ops.aten
prims = torch.ops.prims
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
    """Return the Triton launch grid for the flex-decoding kernel.

    The grid is ``(batch_size * kv_heads, SPLIT_KV, 1)``: axis 0 enumerates
    (batch, kv head) pairs and axis 1 enumerates the KV splits. Each program
    iterates over its own slice of keys/values and produces a partial output
    for all query rows attached to its kv head; the SPLIT_KV partial results
    are combined afterwards.

    ``gqa_group_size``, ``n_keys`` and ``d_model`` are accepted for signature
    compatibility with the template machinery but do not affect the grid.
    """
    programs_per_split = batch_size * kv_heads
    num_splits = meta["SPLIT_KV"]
    return programs_per_split, num_splits, 1
# Triton template for the split-KV ("flex decoding") attention kernel.
# One program handles one (batch, kv head) pair and one KV split (see
# flex_decoding_grid). It loads the GQA_SHARED_HEADS query heads that share
# that kv head as a single [BLOCK_M, BLOCK_DMODEL] tile, runs forward_inner
# over the blocks that need mask_mod + score_mod and, if HAS_FULL_BLOCKS, over
# the "full" blocks that only need score_mod. It then stores its per-split
# rowmax into M, its per-split sumexp into L, and its partial (unnormalized)
# accumulator as the template output. create_flex_decoding_kernel combines
# these partials across splits.
# NOTE(review): forward_inner is not defined inside this source string; it is
# presumably provided by compute_forward_block appended below — confirm.
flex_decoding_template = TritonTemplate(
    name="flex_decoding",
    # Launch grid: (batch_size * kv_heads, SPLIT_KV, 1).
    grid=flex_decoding_grid,
    # Kernel body; the helper sources imported from flex_attention are appended.
    source=r"""
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
# Q: Query, K: Key, V: Value
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
# M: Number of queries, N: Number of keys/values, D(BLOCK_DMODEL): Model dimension
# BLOCK_M, BLOCK_DMODEL: M, and D dimemsion are always assigned to the same block
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
# (Modifiable) Config options:
# SPLIT_KV: number of blocks K & V are split into
# TILE_KV: length of each local KV split
# BLOCK_M: block size that Q is padded along seqlen dim.
# BLOCK_N: block size of K & V along N dimension.
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
#
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
#
#
# Output: ACC output accumulated across local KV split.
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define Q Strides
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
Z = {{size("Q", 0)}}
HKV = {{size("Q", 1)}}
G: tl.constexpr = GQA_SHARED_HEADS
HQ = HKV * G
Q_LEN = {{size("Q", 3)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
# Make sure each split is a multiple of BLOCK_N
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
off_z = tl.program_id(0) // HKV
off_hkv = tl.program_id(0) % HKV
off_t = tl.program_id(1)
q_offset = off_z * stride_qz + off_hkv * stride_qh
k_offset = off_z * stride_kz + off_hkv * stride_kh
v_offset = off_z * stride_vz + off_hkv * stride_vh
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SPARSE_Z
# TODO: support masks not broadcasted along the head dimension.
tl.device_assert(SPARSE_HQ == 1)
sparse_idx_h = 0
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# initialize offsets
tl.device_assert(BLOCK_M % G == 0)
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
off_g = tl.arange(0, G) # [G]
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_hq = offs_g + off_hkv * G
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_d = tl.arange(0, BLOCK_DMODEL)
# KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h
# Calculate KV blocks that belong this CTA.
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(Q_LEN, BLOCK_DMODEL), # (M, d)
strides=(stride_qm, stride_qk),
offsets=(0, 0), # No offset: one CTA per query
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
if SAFE_M_BOUNDARY:
q = tl.load(Q + q_offset + q_range)
else:
mask = off_m[None, :, None] < Q_LEN
q = tl.load(Q + q_offset + q_range, mask)
q = tl.reshape(q, [BLOCK_M, BLOCK_DMODEL])
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apply both score_mod and mask_mod
# find first kv block we are loading and the number of blocks we are loading
kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
# first kv block we're loading
block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(BLOCK_DMODEL, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(BLOCK_DMODEL, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True,
)
m_offset = off_t * stride_mt + off_z * stride_mz
l_offset = off_t * stride_lt + off_z * stride_lz
M_block_ptr = tl.make_block_ptr(
base=M + m_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_mh, stride_mm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
L_block_ptr = tl.make_block_ptr(
base=L + l_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_lh, stride_lm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
if SAFE_M_BOUNDARY:
tl.store(M_block_ptr, m_i)
tl.store(L_block_ptr, l_i)
else:
tl.store(M_block_ptr, m_i, boundary_check=(1,))
tl.store(L_block_ptr, l_i, boundary_check=(1,))
# -- store output
idx_z = off_z
idx_t = off_t
idx_hq = off_hkv*G + off_g[:, None, None]
idx_m = off_m[None, :, None]
idx_d = offs_d[None, None, :]
# TODO generalize and add proper mask support
mask = (idx_m < Q_LEN)
acc = acc.reshape(G, BLOCK_M_PER_HQ, BLOCK_DMODEL)
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
"""
    + compute_forward_block
    + compute_next_offset_func,
)
# Intended upper bound on the number of KV splits (SPLIT_KV).
# NOTE(review): get_split_k below does not clamp its result to this value —
# confirm whether the bound is still meant to be enforced.
MAX_SPLIT_KV = 64
def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int:
    """Choose how many KV splits to launch (heuristic borrowed from xformers).

    With ``B * H`` programs per split, ``SM // (B * H)`` splits give each of
    the ``SM`` multiprocessors at least one program. The result is clamped
    below at 1, and a zero ``B * H`` is treated as 1. ``Mk`` (the key length)
    is accepted but currently unused.
    """
    programs_per_split = max(B * H, 1)  # avoid dividing by zero when B * H == 0
    return max(SM // programs_per_split, 1)
def _get_decoding_default_config(key) -> Tuple[int, int, int]:
default_config = (64, 2, 3)
return default_config
def create_flex_decoding_kernel(*args, **kwargs):
    """Lower flex_attention to the split-KV ``flex_decoding`` Triton template.

    The computation has two stages:

    1. ``flex_decoding_template`` runs on a ``(batch * kv_heads, SPLIT_KV, 1)``
       grid and writes, for every KV split, a partial unnormalized output
       (``buf_ACC``, shape ``[B, SPLIT_KV, Hq, M, D]``), a per-row max
       (``buf_M``) and a per-row sum of exponentials (``buf_L``), both
       ``[B, SPLIT_KV, Hq, M]`` in float32.
    2. The partials are combined here with regular lowerings: each split is
       rescaled by ``exp2(M_split - max_over_splits(M))``, summed over the
       split dimension and normalized by the combined sumexp.

    ``args`` unpacks as ``(query, key, value, block_mask, scale,
    kernel_options, score_mod_subgraph, mask_mod_subgraph,
    score_mod_other_buffers, mask_mod_other_buffers)``; ``kwargs`` is unused.

    Returns:
        ``(output, logsumexp)``. ``output`` is cast to the query dtype.
        ``logsumexp`` is ``log2(sum_splits(alpha * L)) + max_over_splits(M)``
        (log2 units, matching the exp2/log2 used in the reduction).

    Raises:
        ValueError: if the number of query heads is not a multiple of the
            number of kv heads, or if the resulting GQA group size is not a
            power of 2.
    """
    (
        query,
        key,
        value,
        block_mask,
        scale,
        kernel_options,
        score_mod_subgraph,
        mask_mod_subgraph,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        _,  # q_num_blocks
        _,  # q_indices
        _,  # full_q_num_blocks
        _,  # full_q_indices
        SPARSE_KV_BLOCK_SIZE,
        _,  # SPARSE_Q_BLOCK_SIZE
        _,
    ) = block_mask

    # Copy so that the caller's kernel_options are not mutated below.
    kernel_options = dict(kernel_options)

    # Calculate GQA head sharing: each kv head serves `gqa_shared_heads` query heads.
    num_q_heads = query.get_size()[1]
    num_kv_heads = key.get_size()[1]
    if num_q_heads % num_kv_heads != 0:
        # Floor division would silently drop heads and the as_strided reshape
        # below would then view the query with the wrong geometry.
        raise ValueError(
            f"Number of query heads ({num_q_heads}) must be a multiple of the "
            f"number of key/value heads ({num_kv_heads})."
        )
    gqa_shared_heads = num_q_heads // num_kv_heads
    if not is_power_of_2(gqa_shared_heads):
        raise ValueError(
            "Number of shared query heads sharing the same KV head must be power of 2. "
        )
    kernel_options["GQA_SHARED_HEADS"] = gqa_shared_heads

    # Determine if there are "full" blocks where we only need to apply
    # score_mod, and can skip mask_mod.
    kernel_options["HAS_FULL_BLOCKS"] = full_kv_num_blocks is not None
    if full_kv_num_blocks is None:
        # Create empty placeholder buffers so the template always receives the
        # same inputs; HAS_FULL_BLOCKS=False makes the kernel skip them.
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )

    for buf in [
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ]:
        buf.realize()

    choices: List[Any] = []
    configs: List[Tuple[int, int, int]] = []
    configs.append(_get_decoding_default_config(key))
    # Note: max_autotune is not supported yet. Causes error in lowering the
    # dynamic shape in reduction ops.
    # TODO: fix autotuning.

    kernel_options["SM_SCALE"] = scale
    kernel_options["SPLIT_KV"] = get_split_k(
        key.get_size()[0], key.get_size()[1], key.get_size()[2]
    )
    # The intermediate buffers hold exactly one slice per KV split: every
    # program of the grid stores its slice, and the reduction below sums over
    # that dimension, so it must not contain slices the kernel never writes.
    split_kv = kernel_options["SPLIT_KV"]

    # Create config-dependent intermediate buffers.
    buf_ACC_shape = (
        query.get_size()[:1] + [split_kv] + query.get_size()[1:]
    )  # [B, SPLIT_KV, Hq, M, D]
    buf_ML_shape = buf_ACC_shape[:-1]  # [B, SPLIT_KV, Hq, M]
    buf_M = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The rowmax is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    buf_L = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The intermediate sumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    layout_acc = FixedLayout(
        query.get_device(),
        torch.float32,
        buf_ACC_shape,
        FlexibleLayout.contiguous_strides(buf_ACC_shape),
    )

    kernel_options["BLOCK_DMODEL"] = query.get_size()[-1]
    m = query.get_size()[-2]
    # All G query heads of a kv group are packed into a single BLOCK_M tile, so
    # BLOCK_M must cover m * gqa_shared_heads rows. BLOCK_M is kept >= 16 until
    # the Triton fix https://github.com/triton-lang/triton/pull/4061 is in the pin.
    kernel_options["BLOCK_M"] = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
            * gqa_shared_heads
        ),
        16,
    )

    # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
    gqa_query_shape = (
        query.get_size()[:1]
        + [num_kv_heads, gqa_shared_heads]
        + query.get_size()[2:]
    )
    gqa_query_stride = (
        query.get_stride()[:1]
        + [query.get_stride()[1] * gqa_shared_heads, query.get_stride()[1]]
        + query.get_stride()[2:]
    )
    query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)

    # The kernel processes the whole packed query in one tile; guard that it fits.
    V.graph.sizevars.guard_leq(
        m * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
    )
    kernel_options["SAFE_M_BOUNDARY"] = (
        (m * gqa_shared_heads) % kernel_options["BLOCK_M"]
    ) == 0
    kernel_options["SAFE_N_BOUNDARY"] = True

    # Tensor inputs of the template, in the order of its def_kernel signature:
    # (Q, K, V, M, L, KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX).
    template_inputs = [
        query,
        key,
        value,
        buf_M,
        buf_L,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ]

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.
    for BLOCK_N, num_warps, num_stages in configs:
        # The template requires SPARSE_KV_BLOCK_SIZE to be a multiple of BLOCK_N.
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0:
            continue

        # Performance tuning
        kernel_options["BLOCK_N"] = BLOCK_N
        kernel_options["SPARSE_KV_BLOCK_SIZE"] = SPARSE_KV_BLOCK_SIZE

        flex_decoding_template.maybe_append_choice(
            choices=choices,
            input_nodes=list(template_inputs),
            layout=layout_acc,
            subgraphs=[
                score_mod_subgraph,
                mask_mod_subgraph,
            ],
            mutated_inputs=[buf_M, buf_L],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            **kernel_options,
        )

    inputs_for_flex_decoding = (
        template_inputs
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    buf_ACC = autotune_select_algorithm(
        "flex_decoding", choices, inputs_for_flex_decoding, layout_acc
    )

    # Reduction across KV splits (dim 1):
    #   g_M   = max over splits of the per-split rowmax
    #   alpha = exp2(M_split - g_M), rescales every split to the common max
    #   g_L   = sum over splits of alpha * L_split
    g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
    adj_M = lowerings[aten.sub](buf_M, g_M)
    alpha = lowerings[aten.exp2](adj_M)

    buf_L = lowerings[aten.mul](buf_L, alpha)
    g_L = lowerings[aten.sum](buf_L, axis=1)
    logsumexp = lowerings[aten.log2](g_L)
    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))

    # Rescale each split's accumulator, sum over splits and normalize.
    alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
    buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
    output = lowerings[aten.sum](buf_ACC, axis=1)
    L_unseq = lowerings[aten.unsqueeze](g_L, 3)
    output = lowerings[aten.div](output, L_unseq)
    output = lowerings[prims.convert_element_type](output, query.get_dtype())

    return (
        output,
        logsumexp,
    )