Refactored flexattention kernel (#130904)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130904
Approved by: https://github.com/drisspg
ghstack dependencies: #130871
This commit is contained in:
chilli 2024-07-18 10:01:42 -07:00 committed by PyTorch MergeBot
parent ac76dd606f
commit d59803fb67
4 changed files with 139 additions and 104 deletions

View file

@ -1621,9 +1621,9 @@ class GraphModule(torch.nn.Module):
NotImplementedError, "NYI: L must be a multiple of 128"
):
flex_attention(
torch.randn((2, 3, 4)),
torch.randn((2, 10, 5)),
torch.randn((2, 10, 5)),
torch.randn((1, 2, 3, 4)),
torch.randn((1, 2, 10, 5)),
torch.randn((1, 2, 10, 5)),
score_mod=_identity,
)
@ -1632,9 +1632,9 @@ class GraphModule(torch.nn.Module):
):
compiled_flex = torch.compile(flex_attention)
compiled_flex(
torch.randn((2, 3, 4)),
torch.randn((2, 10, 5)),
torch.randn((2, 10, 5)),
torch.randn((1, 2, 3, 4)),
torch.randn((1, 2, 10, 5)),
torch.randn((1, 2, 10, 5)),
score_mod=_identity,
)

View file

@ -12,6 +12,7 @@ from torch.utils._pytree import tree_map
from .. import config
from ..ir import (
ComputedBuffer,
ExternKernel,
FixedLayout,
FlexibleLayout,
InputBuffer,
@ -110,6 +111,19 @@ def build_subgraph_buffer(
raise ValueError("FlexAttention was passed a subgraph with no output node!")
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    # Returns how far the K/V block pointers should be advanced after finishing
    # inner-loop iteration `loop_iter` over BLOCK-sized tiles, given a
    # block-sparse index array `col_indices` of `total_blocks` entries.
    # Assumes SPARSE_BLOCK == SPARSE_BLOCK_MULTIPLE * BLOCK — TODO confirm
    # against the callers that pass these constexprs.
    # Index of the sparse block the current tile belongs to.
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # NOTE(review): masked load — when cur_block_idx + 1 >= total_blocks the
    # loaded value is undefined; presumably the loop bound ensures needs_jump
    # is never taken on the last sparse block — verify with caller.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    # Jump to the next sparse block only after the last tile of the current one.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance from the current position to the start of the next sparse block,
    # after having already advanced (SPARSE_BLOCK_MULTIPLE - 1) tiles within
    # the current block.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branch-free select: jump across blocks, or step one tile forward.
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
"""
flex_attention_template = TritonTemplate(
name="flex_attention",
grid=flex_attention_grid,
@ -176,8 +190,8 @@ flex_attention_template = TritonTemplate(
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_Q_BLOCK_CNT: tl.constexpr = Q_LEN // SPARSE_Q_BLOCK_SIZE
SPARSE_KV_BLOCK_CNT: tl.constexpr = KV_LEN // SPARSE_KV_BLOCK_SIZE
SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@ -207,7 +221,7 @@ flex_attention_template = TritonTemplate(
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
K_block_ptr = tl.make_block_ptr(
base=K,
@ -225,15 +239,14 @@ flex_attention_template = TritonTemplate(
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
acc, l_i, m_i,
off_z, off_h, offs_m,
kv_start,
kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
off_z, off_h, offs_m, offs_n,
kv_indices, kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=False
)
@ -245,7 +258,7 @@ flex_attention_template = TritonTemplate(
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
K_block_ptr = tl.make_block_ptr(
base=K,
@ -263,17 +276,14 @@ flex_attention_template = TritonTemplate(
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# initialize offsets
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
q, K_block_ptr, V_block_ptr,
acc, l_i, m_i,
off_z, off_h, offs_m,
kv_start,
kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
off_z, off_h, offs_m, offs_n,
kv_indices, kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True
)
@ -303,27 +313,31 @@ flex_attention_template = TritonTemplate(
@triton.jit
def forward_inner(
q, K_block_ptr, V_block_ptr,
# accumulated values
acc, l_i, m_i,
off_z, off_h, offs_m,
kv_start,
kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
# Offsets
off_z, off_h, offs_m, offs_n,
# blocksparse data
kv_indices, kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
# initialize offsets
offs_n = kv_start + tl.arange(0, BLOCK_N)
RCP_LN2 = 1.44269504
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
# loop over k, v and update accumulator
lo = 0
hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
hi = kv_num_blocks * SPARSE_KV_MULTIPLE
for start_n in range(0, hi):
# -- load k --
@ -387,14 +401,7 @@ def forward_inner(
m_i = m_ij
# update pointers
indices_idx = start_n // SPARSE_KV_MULTIPLE
cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_kv_num_blocks)
needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N
offset = get_offset_for_next_block(start_n, kv_indices, kv_num_blocks, SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N)
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
@ -402,8 +409,8 @@ def forward_inner(
offs_n = offs_n + offset
return acc, l_i, m_i
""",
"""
+ compute_next_offset_func,
)
@ -737,7 +744,7 @@ flex_attention_backward_template = TritonTemplate(
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT* DO, axis=1)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
# inductor codegen
@ -781,7 +788,7 @@ flex_attention_backward_template = TritonTemplate(
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = KV_LEN // BLOCK_N1
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
off_hz = tl.program_id(2)
off_z = off_hz // H # batch idx
@ -793,9 +800,6 @@ flex_attention_backward_template = TritonTemplate(
sparse_idx_z = off_z % SM_Z
sparse_idx_h = off_h % SM_H
SPARSE_Q_BLOCK_CNT = Q_LEN // SPARSE_Q_BLOCK_SIZE
SPARSE_KV_BLOCK_CNT = KV_LEN // SPARSE_KV_BLOCK_SIZE
sparse_hz_offset = sparse_idx_z * SM_H + sparse_idx_h
off_chz = (off_hz * Q_LEN).to(tl.int64)
@ -827,9 +831,15 @@ flex_attention_backward_template = TritonTemplate(
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_pid_mask = off_pid // SPARSE_Q_MULTIPLE
sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + off_pid // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (off_pid // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
KV_IDX_N = {{size("KV_IDX", 3)}}
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
@ -850,15 +860,17 @@ flex_attention_backward_template = TritonTemplate(
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_start = tl.load(KV_IDX + sparse_kv_idx_offset) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_k,
off_z, off_h, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=False
)
@ -870,12 +882,14 @@ flex_attention_backward_template = TritonTemplate(
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_k,
off_z, off_h, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True
)
@ -889,8 +903,15 @@ flex_attention_backward_template = TritonTemplate(
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
sparse_q_num_blks_offset = sparse_hz_offset * SPARSE_KV_BLOCK_CNT + pid // SPARSE_KV_MULTIPLE
sparse_q_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (pid // SPARSE_KV_MULTIPLE) * SPARSE_Q_BLOCK_CNT # noqa: B950
pid_mask = pid // SPARSE_KV_MULTIPLE
Q_IDX_M = {{size("Q_IDX", 3)}}
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
stride_q_idx_h = {{stride("Q_IDX", 1)}}
stride_q_idx_n = {{stride("Q_IDX", 2)}}
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask * stride_q_idx_n # noqa: B950
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
@ -910,14 +931,15 @@ flex_attention_backward_template = TritonTemplate(
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
start_m1 = q_start
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
dk, dv = bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
off_z, off_h, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=False
)
@ -930,51 +952,56 @@ flex_attention_backward_template = TritonTemplate(
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
start_m1 = q_start
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
dk, dv = bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
off_z, off_h, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}},
IS_FULL_BLOCKS=True
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd
tl.store(dv_ptrs, dv)
# Write back dK.
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
tl.store(dv_ptrs, dv)
dk *= SM_SCALE
mask = index_n <= KV_LEN
mask = index_n < KV_LEN
{{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
@triton.jit
def bwd_dq_inner(
dq, q, K, V, do, Di, lse,
off_z, off_h, offs_m2, offs_k,
off_z, off_h, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}}, IS_FULL_BLOCKS
):
start_n2 = kv_start
offs_n2 = start_n2 + tl.arange(0, BLOCK_N2)
{{gen_defines() | indent_except_first(1) }}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, BLOCK_DMODEL)
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_k[:, None] * stride_vd
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n2
hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
for start_n in range(0, hi):
offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
qk = tl.dot(q, kT)
if not PRESCALE_QK:
qk *= SM_SCALE
@ -1011,6 +1038,7 @@ def bwd_dq_inner(
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
vT = tl.load(vT_ptrs)
dp = tl.dot(do, vT)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
@ -1035,17 +1063,12 @@ def bwd_dq_inner(
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
indices_idx = start_n // SPARSE_KV_MULTIPLE
cur_block = tl.load(kv_indices + indices_idx)
next_block = tl.load(kv_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_kv_num_blocks)
needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N2
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N2
offset = get_offset_for_next_block(start_n, kv_indices, sparse_kv_num_blocks, SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
curr_n += offset
offs_n2 += offset
return dq
@ -1053,25 +1076,29 @@ def bwd_dq_inner(
@triton.jit
def bwd_dkdv_inner(
dk, dv, Q, k, v, DO, DELTA, LSE,
off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
off_z, off_h, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
{{gen_argdefs()}}, IS_FULL_BLOCKS
):
offs_m1 = start_m1 + tl.arange(0, BLOCK_M1)
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
{{gen_defines() | indent_except_first(1) }}
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, BLOCK_DMODEL)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_k[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m1
hi = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
for q_start in range(0, hi):
qT = tl.load(qT_ptrs)
for start_m in range(0, hi):
# Load LSE before computing qk to reduce pipeline stall.
offs_m1 = curr_m + tl.arange(0, BLOCK_M1)
qT = tl.load(qT_ptrs)
lse = tl.load(LSE + offs_m1)
qkT = tl.dot(k, qT)
if not PRESCALE_QK:
@ -1113,6 +1140,7 @@ def bwd_dkdv_inner(
Di = tl.load(DELTA + offs_m1)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
# dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
m = offs_m1[None, :]
@ -1134,21 +1162,16 @@ def bwd_dkdv_inner(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
# Increment pointers.
indices_idx = q_start // SPARSE_Q_MULTIPLE
cur_block = tl.load(q_indices + indices_idx)
next_block = tl.load(q_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_q_num_blocks)
needs_jump = (q_start + 1) % SPARSE_Q_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_Q_BLOCK_SIZE - (SPARSE_Q_MULTIPLE - 1) * BLOCK_M1
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_M1
offset = get_offset_for_next_block(start_m, q_indices, sparse_q_num_blocks, SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
curr_m += offset
offs_m1 += offset
return dk, dv
""",
"""
+ compute_next_offset_func,
)
@ -1252,6 +1275,7 @@ def flex_attention_backward(*args, **kwargs):
# Create delta which is needed for the bwd kernel
mul_delta = lowerings[aten.mul](out, grad_out)
delta = lowerings[aten.sum](mul_delta, axis=-1)
delta = ExternKernel.require_contiguous(delta)
# see NOTE:[TritonTemplates with multiple outputs]
grad_query = empty_strided(

View file

@ -256,6 +256,9 @@ class TritonTemplateKernel(TritonKernel):
self.render_hooks["<ARGDEFS>"] = hook
return "<ARGDEFS>"
def gen_defines(self):
return self.defines
def def_kernel(self, *argnames):
"""
Hook called from template code to generate function def and
@ -515,6 +518,7 @@ class TritonTemplateKernel(TritonKernel):
self.make_load,
self.modification,
self.gen_argdefs,
self.gen_defines,
]
}
@ -632,7 +636,7 @@ class TritonTemplate(KernelTemplate):
assert self.template, "requires jinja2"
defines = StringIO()
for name, val in kwargs.items():
defines.write(f" {name} : tl.constexpr = {val}\n")
defines.write(f"{name} : tl.constexpr = {val}\n")
defines = defines.getvalue()
fake_out = ir.Buffer("buf_out", layout)

View file

@ -439,6 +439,10 @@ def _broadcast_to_dim(x, dim):
return x
def round_up_to_multiple(x, multiple):
return (x + multiple - 1) // multiple * multiple
def _convert_mask_to_block_mask(
mask: Tensor,
KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
@ -620,7 +624,8 @@ def create_block_mask(
mod_type == _ModificationType.MASK_MOD
), "create-block_mask requires a mask_mod function!"
inner_func = _create_block_mask_inner
# Temporary work around see: _create_block_mask_inner for more details
Q_LEN = round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
if _compile:
inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
with TransformGetItemToIndex():
@ -637,8 +642,8 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
of the query and key tensors.
"""
device = query.device
kv_len: int = key.size()[-2]
q_len: int = query.size()[-2]
kv_len = round_up_to_multiple(key.size()[-2], 128)
q_len = round_up_to_multiple(query.size()[-2], 128)
return BlockMask(
kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
@ -708,6 +713,8 @@ def flex_attention(
"""
# Some basic input validation
_validate_sdpa_input(query, key, value)
if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
if query.size(-2) >= 32: # use Attention Kernel
if query.size(-2) >= 128 and query.size(-2) % 128 != 0:
raise NotImplementedError("NYI: S must be <128 or a multiple of 128")