Refactored flexattention kernel (#130904)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130904
Approved by: https://github.com/drisspg
ghstack dependencies: #130871
parent ac76dd606f · commit d59803fb67
4 changed files with 139 additions and 104 deletions
test/inductor/test_flex_attention.py

@@ -1621,9 +1621,9 @@ class GraphModule(torch.nn.Module):
             NotImplementedError, "NYI: L must be a multiple of 128"
         ):
             flex_attention(
-                torch.randn((2, 3, 4)),
-                torch.randn((2, 10, 5)),
-                torch.randn((2, 10, 5)),
+                torch.randn((1, 2, 3, 4)),
+                torch.randn((1, 2, 10, 5)),
+                torch.randn((1, 2, 10, 5)),
                 score_mod=_identity,
             )
@@ -1632,9 +1632,9 @@ class GraphModule(torch.nn.Module):
         ):
             compiled_flex = torch.compile(flex_attention)
             compiled_flex(
-                torch.randn((2, 3, 4)),
-                torch.randn((2, 10, 5)),
-                torch.randn((2, 10, 5)),
+                torch.randn((1, 2, 3, 4)),
+                torch.randn((1, 2, 10, 5)),
+                torch.randn((1, 2, 10, 5)),
                 score_mod=_identity,
             )

torch/_inductor/kernel/flex_attention.py
@@ -12,6 +12,7 @@ from torch.utils._pytree import tree_map
 from .. import config
 from ..ir import (
     ComputedBuffer,
+    ExternKernel,
     FixedLayout,
     FlexibleLayout,
     InputBuffer,
@@ -110,6 +111,19 @@ def build_subgraph_buffer(
     raise ValueError("FlexAttention was passed a subgraph with no output node!")


+compute_next_offset_func = r"""
+@triton.jit
+def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
+    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
+    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
+    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
+    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
+    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
+
+    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
+    return offset
+"""
+
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
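The new `get_offset_for_next_block` helper centralizes the pointer-advance arithmetic that was previously duplicated in the forward and backward loops. A minimal pure-Python model of the same arithmetic (illustration only, assuming a plain Python list in place of the `col_indices` pointer; the Triton version loads through raw pointers and masks the out-of-range read):

```python
# Pure-Python model of get_offset_for_next_block (hypothetical host-side sketch;
# the real helper runs inside a Triton kernel).
def get_offset_for_next_block(loop_iter, col_indices, total_blocks,
                              SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_block_idx]
    # Masked load in Triton; here we simply guard the index.
    next_block = col_indices[cur_block_idx + 1] if cur_block_idx + 1 < total_blocks else 0
    # Only jump to a new sparse block once all of its sub-blocks are done.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Jump distance: gap between sparse blocks, minus sub-block steps already taken.
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_to_block if needs_jump else BLOCK

# Example: 128-wide sparse blocks split into 64-wide sub-blocks (MULTIPLE = 2),
# visiting sparse blocks [0, 3, 7]:
offsets = [get_offset_for_next_block(i, [0, 3, 7], 3, 128, 2, 64) for i in range(5)]
print(offsets)  # [64, 320, 64, 448, 64] -> step within a block, then jump 3*128-64, ...
```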
@@ -176,8 +190,8 @@ flex_attention_template = TritonTemplate(
     SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
     SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

-    SPARSE_Q_BLOCK_CNT: tl.constexpr = Q_LEN // SPARSE_Q_BLOCK_SIZE
-    SPARSE_KV_BLOCK_CNT: tl.constexpr = KV_LEN // SPARSE_KV_BLOCK_SIZE
+    SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
+    SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
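The switch from floor division to `tl.cdiv` makes the block counts round up, so sequence lengths that are not an exact multiple of the sparse block size still get a trailing (partial) block. A one-line illustration of the difference:

```python
def cdiv(a, b):  # ceiling division, same semantics as tl.cdiv
    return (a + b - 1) // b

Q_LEN, SPARSE_Q_BLOCK_SIZE = 1000, 128
print(Q_LEN // SPARSE_Q_BLOCK_SIZE)      # 7 -> the last 104 rows would get no block
print(cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE))  # 8 -> the partial trailing block is counted
```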
@@ -207,7 +221,7 @@ flex_attention_template = TritonTemplate(
     # both score_mod and mask_mod to it
     kv_indices = KV_IDX + sparse_kv_idx_offset
     kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
-    sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
+    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

     K_block_ptr = tl.make_block_ptr(
         base=K,
@@ -225,15 +239,14 @@ flex_attention_template = TritonTemplate(
         block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
     offs_n = kv_start + tl.arange(0, BLOCK_N)

     acc, l_i, m_i = forward_inner(
         q, K_block_ptr, V_block_ptr,
         acc, l_i, m_i,
-        off_z, off_h, offs_m,
-        kv_start,
-        kv_indices, sparse_kv_num_blocks,
-        SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
-        BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
+        off_z, off_h, offs_m, offs_n,
+        kv_indices, kv_num_blocks,
+        MATMUL_PRECISION,
+        {{gen_argdefs()}},
         IS_FULL_BLOCKS=False
     )
@@ -245,7 +258,7 @@ flex_attention_template = TritonTemplate(
     # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
     kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
     kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
-    sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
+    kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

     K_block_ptr = tl.make_block_ptr(
         base=K,
@@ -263,17 +276,14 @@ flex_attention_template = TritonTemplate(
         block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
-    # initialize offsets
     offs_n = kv_start + tl.arange(0, BLOCK_N)

     acc, l_i, m_i = forward_inner(
         q, K_block_ptr, V_block_ptr,
         acc, l_i, m_i,
-        off_z, off_h, offs_m,
-        kv_start,
-        kv_indices, sparse_kv_num_blocks,
-        SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
-        BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
+        off_z, off_h, offs_m, offs_n,
+        kv_indices, kv_num_blocks,
+        MATMUL_PRECISION,
+        {{gen_argdefs()}},
         IS_FULL_BLOCKS=True
     )
@@ -303,27 +313,31 @@ flex_attention_template = TritonTemplate(
 @triton.jit
 def forward_inner(
     q, K_block_ptr, V_block_ptr,
+    # accumulated values
     acc, l_i, m_i,
-    off_z, off_h, offs_m,
-    kv_start,
-    kv_indices, sparse_kv_num_blocks,
-    SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
-    BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
+    # Offsets
+    off_z, off_h, offs_m, offs_n,
+    # blocksparse data
+    kv_indices, kv_num_blocks,
+    MATMUL_PRECISION,
+    {{gen_argdefs()}},
     IS_FULL_BLOCKS,
 ):
-    # initialize offsets
-    offs_n = kv_start + tl.arange(0, BLOCK_N)
-
-    RCP_LN2 = 1.44269504
+    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
+    {{gen_defines() | indent_except_first(1)}}
+
+    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
+
+    RCP_LN2: tl.constexpr = 1.44269504

     if PRESCALE_QK:
         q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

     # loop over k, v and update accumulator
     lo = 0
-    hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
+    hi = kv_num_blocks * SPARSE_KV_MULTIPLE

     for start_n in range(0, hi):
         # -- load k --
@@ -387,14 +401,7 @@ def forward_inner(
         m_i = m_ij

         # update pointers
-        indices_idx = start_n // SPARSE_KV_MULTIPLE
-
-        cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
-        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_kv_num_blocks)
-        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
-        jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
-
-        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N
+        offset = get_offset_for_next_block(start_n, kv_indices, kv_num_blocks, SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N)

         V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
         K_block_ptr = tl.advance(K_block_ptr, (0, offset))
@@ -402,8 +409,8 @@ def forward_inner(
     offs_n = offs_n + offset

     return acc, l_i, m_i

-""",
+"""
+    + compute_next_offset_func,
 )
@@ -737,7 +744,7 @@ flex_attention_backward_template = TritonTemplate(
 #
 # Q: Query, K: Key, V: Value
 # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
-# DELTA: Precomputed sum(OUT* DO, axis=1)
+# DELTA: Precomputed sum(OUT*DO, axis=-1)
 # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
 # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
 # inductor codegen
@@ -781,7 +788,7 @@ flex_attention_backward_template = TritonTemplate(
     MATMUL_PRECISION = Q.dtype.element_ty

     pid = tl.program_id(0)
-    NUM_KV_BLOCKS = KV_LEN // BLOCK_N1
+    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)

     off_hz = tl.program_id(2)
     off_z = off_hz // H # batch idx
@@ -793,9 +800,6 @@ flex_attention_backward_template = TritonTemplate(
     sparse_idx_z = off_z % SM_Z
     sparse_idx_h = off_h % SM_H

-    SPARSE_Q_BLOCK_CNT = Q_LEN // SPARSE_Q_BLOCK_SIZE
-    SPARSE_KV_BLOCK_CNT = KV_LEN // SPARSE_KV_BLOCK_SIZE
-
     sparse_hz_offset = sparse_idx_z * SM_H + sparse_idx_h

     off_chz = (off_hz * Q_LEN).to(tl.int64)
@@ -827,9 +831,15 @@ flex_attention_backward_template = TritonTemplate(
         # THIS BLOCK DOES DQ
         SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
         SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
+        off_pid_mask = off_pid // SPARSE_Q_MULTIPLE

-        sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + off_pid // SPARSE_Q_MULTIPLE
-        sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (off_pid // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
+        KV_IDX_N = {{size("KV_IDX", 3)}}
+        stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
+        stride_kv_idx_h = {{stride("KV_IDX", 1)}}
+        stride_kv_idx_m = {{stride("KV_IDX", 2)}}
+
+        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
+        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950

         dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
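Both backward offset blocks (this one and the dk/dv branch below) swap compile-time block-count arithmetic for strides read from the block-mask metadata itself, so the kernel indexes `KV_NUM_BLKS`/`KV_IDX` however those tensors are actually laid out. A rough Python sketch of why the stride-based offset matches the old count-based arithmetic for contiguous metadata (hypothetical shapes and values; in the kernel the strides come from the `{{stride(...)}}` codegen hooks):

```python
import torch

# Hypothetical block-mask metadata: [B, H, Q_BLOCKS, KV_BLOCKS]
kv_idx = torch.zeros(2, 4, 8, 8, dtype=torch.int32)

sparse_hz_offset, off_pid_mask = 3, 5

# Old style: flat offset rebuilt from assumed-contiguous block counts.
SPARSE_Q_BLOCK_CNT, SPARSE_KV_BLOCK_CNT = kv_idx.size(2), kv_idx.size(3)
old_offset = (sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT
              + off_pid_mask * SPARSE_KV_BLOCK_CNT)

# New style: flat offset read off the tensor's own strides, so any layout works.
stride_kv_idx_h, stride_kv_idx_m = kv_idx.stride(1), kv_idx.stride(2)
new_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m

assert old_offset == new_offset  # equal here only because this tensor is contiguous
```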
@@ -850,15 +860,17 @@ flex_attention_backward_template = TritonTemplate(
         # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # KV_IDX and KV_NUM_BLKS are always contiguous.
         kv_indices = KV_IDX + sparse_kv_idx_offset
-        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
+        kv_start = tl.load(KV_IDX + sparse_kv_idx_offset) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
         sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

         offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
+        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
         dq = bwd_dq_inner(
             dq, q, K, V, do, Di, lse,
-            off_z, off_h, offs_m2, offs_k,
+            off_z, off_h, offs_m2, offs_n2,
             stride_kn, stride_kd, stride_vn, stride_vd,
-            kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
-            BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+            kv_indices, sparse_kv_num_blocks,
+            MATMUL_PRECISION,
+            {{gen_argdefs()}},
             IS_FULL_BLOCKS=False
         )
@@ -870,12 +882,14 @@ flex_attention_backward_template = TritonTemplate(
         kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
         sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

         offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
+        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
         dq = bwd_dq_inner(
             dq, q, K, V, do, Di, lse,
-            off_z, off_h, offs_m2, offs_k,
+            off_z, off_h, offs_m2, offs_n2,
             stride_kn, stride_kd, stride_vn, stride_vd,
-            kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
-            BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+            kv_indices, sparse_kv_num_blocks,
+            MATMUL_PRECISION,
+            {{gen_argdefs()}},
             IS_FULL_BLOCKS=True
         )
@@ -889,8 +903,15 @@ flex_attention_backward_template = TritonTemplate(
         SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
         SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

-        sparse_q_num_blks_offset = sparse_hz_offset * SPARSE_KV_BLOCK_CNT + pid // SPARSE_KV_MULTIPLE
-        sparse_q_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (pid // SPARSE_KV_MULTIPLE) * SPARSE_Q_BLOCK_CNT # noqa: B950
+        pid_mask = pid // SPARSE_KV_MULTIPLE
+
+        Q_IDX_M = {{size("Q_IDX", 3)}}
+        stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
+        stride_q_idx_h = {{stride("Q_IDX", 1)}}
+        stride_q_idx_n = {{stride("Q_IDX", 2)}}
+
+        sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
+        sparse_q_idx_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask * stride_q_idx_n # noqa: B950

         dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
         dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
@@ -910,14 +931,15 @@ flex_attention_backward_template = TritonTemplate(
         q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
         sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

-        start_m1 = q_start
+        offs_m1 = q_start + tl.arange(0, BLOCK_M1)
         offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

         dk, dv = bwd_dkdv_inner(
             dk, dv, Q, k, v, DO, DELTA, LSE,
-            off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
+            off_z, off_h, offs_n1, offs_m1,
             stride_qm, stride_qd, stride_dom, stride_dod,
-            q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
-            BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+            q_indices, sparse_q_num_blocks,
+            MATMUL_PRECISION,
+            {{gen_argdefs()}},
             IS_FULL_BLOCKS=False
         )
@@ -930,51 +952,56 @@ flex_attention_backward_template = TritonTemplate(
             q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
             sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

-            start_m1 = q_start
+            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
             offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

             dk, dv = bwd_dkdv_inner(
                 dk, dv, Q, k, v, DO, DELTA, LSE,
-                off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
+                off_z, off_h, offs_n1, offs_m1,
                 stride_qm, stride_qd, stride_dom, stride_dod,
-                q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
-                BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+                q_indices, sparse_q_num_blocks,
+                MATMUL_PRECISION,
+                {{gen_argdefs()}},
                 IS_FULL_BLOCKS=True
             )

         # Write back dV and dK.
         dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd
-        tl.store(dv_ptrs, dv)
-
-        # Write back dK.
         index_n = offs_n1[:, None]
         index_k = offs_k[None, :]
+
+        tl.store(dv_ptrs, dv)
+
         dk *= SM_SCALE
-        mask = index_n <= KV_LEN
+        mask = index_n < KV_LEN
         {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}

 @triton.jit
 def bwd_dq_inner(
     dq, q, K, V, do, Di, lse,
-    off_z, off_h, offs_m2, offs_k,
+    off_z, off_h, offs_m2, offs_n2,
     stride_kn, stride_kd, stride_vn, stride_vd,
-    kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
-    BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+    kv_indices, sparse_kv_num_blocks,
+    MATMUL_PRECISION,
     {{gen_argdefs()}}, IS_FULL_BLOCKS
 ):
-    start_n2 = kv_start
-    offs_n2 = start_n2 + tl.arange(0, BLOCK_N2)
+    {{gen_defines() | indent_except_first(1) }}
+    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
+    RCP_LN2: tl.constexpr = 1.44269504
+    Q_LEN = {{size("Q", 2)}}
+    KV_LEN = {{size("K", 2)}}
+
+    offs_k = tl.arange(0, BLOCK_DMODEL)
+
     kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
     vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_k[:, None] * stride_vd
     # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

-    curr_n = start_n2
     hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
     for start_n in range(0, hi):
-        offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
         kT = tl.load(kT_ptrs)
-        vT = tl.load(vT_ptrs)
         qk = tl.dot(q, kT)
         if not PRESCALE_QK:
             qk *= SM_SCALE
@@ -1011,6 +1038,7 @@ def bwd_dq_inner(
         post_mod_scores *= RCP_LN2
         p = tl.math.exp2(post_mod_scores - lse)
         # Compute dP and dS.
+        vT = tl.load(vT_ptrs)
         dp = tl.dot(do, vT)
         ds = p * (dp - Di[:, None])
         # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
@@ -1035,17 +1063,12 @@ def bwd_dq_inner(
         dq += tl.dot(ds, tl.trans(kT))

         # Increment pointers.
-        indices_idx = start_n // SPARSE_KV_MULTIPLE
-        cur_block = tl.load(kv_indices + indices_idx)
-        next_block = tl.load(kv_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_kv_num_blocks)
-        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
-        jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N2
-        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N2
+        offset = get_offset_for_next_block(start_n, kv_indices, sparse_kv_num_blocks, SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2)

         kT_ptrs += offset * stride_kn
         vT_ptrs += offset * stride_vn

-        curr_n += offset
+        offs_n2 += offset

     return dq
@@ -1053,25 +1076,29 @@ def bwd_dq_inner(
 @triton.jit
 def bwd_dkdv_inner(
     dk, dv, Q, k, v, DO, DELTA, LSE,
-    off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
+    off_z, off_h, offs_n1, offs_m1,
     stride_qm, stride_qd, stride_dom, stride_dod,
-    q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
-    BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
+    q_indices, sparse_q_num_blocks,
+    MATMUL_PRECISION,
     {{gen_argdefs()}}, IS_FULL_BLOCKS
 ):
-    offs_m1 = start_m1 + tl.arange(0, BLOCK_M1)
-    offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
+    {{gen_defines() | indent_except_first(1) }}
+    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
+    RCP_LN2: tl.constexpr = 1.44269504
+    Q_LEN = {{size("Q", 2)}}
+    KV_LEN = {{size("K", 2)}}
+
+    offs_k = tl.arange(0, BLOCK_DMODEL)
+
     qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
     do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_k[None, :] * stride_dod
     # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

-    curr_m = start_m1
     hi = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
-    for q_start in range(0, hi):
-        qT = tl.load(qT_ptrs)
-        offs_m1 = curr_m + tl.arange(0, BLOCK_M1)
-
+    for start_m in range(0, hi):
+        # Load LSE before computing qk to reduce pipeline stall.
+        qT = tl.load(qT_ptrs)
         lse = tl.load(LSE + offs_m1)
         qkT = tl.dot(k, qT)
         if not PRESCALE_QK:
@@ -1113,6 +1140,7 @@ def bwd_dkdv_inner(
         Di = tl.load(DELTA + offs_m1)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do))
+        # dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
         dsT = pT * (dpT - Di[None, :])
         # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
         m = offs_m1[None, :]
@@ -1134,21 +1162,16 @@ def bwd_dkdv_inner(
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
         # Increment pointers.
-        indices_idx = q_start // SPARSE_Q_MULTIPLE
-        cur_block = tl.load(q_indices + indices_idx)
-        next_block = tl.load(q_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_q_num_blocks)
-        needs_jump = (q_start + 1) % SPARSE_Q_MULTIPLE == 0
-        jump_to_block = (next_block - cur_block ) * SPARSE_Q_BLOCK_SIZE - (SPARSE_Q_MULTIPLE - 1) * BLOCK_M1
-        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_M1
+        offset = get_offset_for_next_block(start_m, q_indices, sparse_q_num_blocks, SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1)

         qT_ptrs += offset * stride_qm
         do_ptrs += offset * stride_dom

-        curr_m += offset
+        offs_m1 += offset

     return dk, dv

-""",
+"""
+    + compute_next_offset_func,
 )
@@ -1252,6 +1275,7 @@ def flex_attention_backward(*args, **kwargs):
     # Create delta which will is needed for the bwd's kernel
     mul_delta = lowerings[aten.mul](out, grad_out)
     delta = lowerings[aten.sum](mul_delta, axis=-1)
+    delta = ExternKernel.require_contiguous(delta)

     # see NOTE:[TritonTemplates with multiple outputs]
     grad_query = empty_strided(

torch/_inductor/select_algorithm.py
@@ -256,6 +256,9 @@ class TritonTemplateKernel(TritonKernel):
         self.render_hooks["<ARGDEFS>"] = hook
         return "<ARGDEFS>"

+    def gen_defines(self):
+        return self.defines
+
     def def_kernel(self, *argnames):
         """
         Hook called from template code to generate function def and
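The new `gen_defines` hook mirrors `gen_argdefs`: instead of threading every `tl.constexpr` parameter through helper-function argument lists, a template can splice the defines block straight into each `@triton.jit` helper body. A toy sketch of the mechanism (simplified; the real class wires this through jinja render hooks and `indent_except_first`):

```python
from io import StringIO

# Toy stand-in for how TritonTemplate turns kwargs into constexpr defines
# and how a template body splices them back in via gen_defines().
class ToyTemplateKernel:
    def __init__(self, **kwargs):
        buf = StringIO()
        for name, val in kwargs.items():
            buf.write(f"{name} : tl.constexpr = {val}\n")
        self.defines = buf.getvalue()

    def gen_defines(self):
        return self.defines

k = ToyTemplateKernel(BLOCK_M=128, BLOCK_N=64, PRESCALE_QK=False)
print(k.gen_defines())
# BLOCK_M : tl.constexpr = 128
# BLOCK_N : tl.constexpr = 64
# PRESCALE_QK : tl.constexpr = False
```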
@@ -515,6 +518,7 @@ class TritonTemplateKernel(TritonKernel):
                 self.make_load,
                 self.modification,
                 self.gen_argdefs,
+                self.gen_defines,
             ]
         }
@@ -632,7 +636,7 @@ class TritonTemplate(KernelTemplate):
         assert self.template, "requires jinja2"
         defines = StringIO()
         for name, val in kwargs.items():
-            defines.write(f"    {name} : tl.constexpr = {val}\n")
+            defines.write(f"{name} : tl.constexpr = {val}\n")
         defines = defines.getvalue()

         fake_out = ir.Buffer("buf_out", layout)

torch/nn/attention/flex_attention.py
@@ -439,6 +439,10 @@ def _broadcast_to_dim(x, dim):
     return x


+def round_up_to_multiple(x, multiple):
+    return (x + multiple - 1) // multiple * multiple
+
+
 def _convert_mask_to_block_mask(
     mask: Tensor,
     KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
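`round_up_to_multiple` is ordinary ceil-to-multiple arithmetic, e.g.:

```python
def round_up_to_multiple(x, multiple):
    return (x + multiple - 1) // multiple * multiple

assert round_up_to_multiple(1000, 128) == 1024
assert round_up_to_multiple(1024, 128) == 1024  # already a multiple: unchanged
```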
@@ -620,7 +624,8 @@ def create_block_mask(
         mod_type == _ModificationType.MASK_MOD
     ), "create-block_mask requires a mask_mod function!"
     inner_func = _create_block_mask_inner
+    # Temporary work around see: _create_block_mask_inner for more details
+    Q_LEN = round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
+    KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
     if _compile:
         inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
     with TransformGetItemToIndex():
@@ -637,8 +642,8 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
     of the query and key tensors.
     """
     device = query.device
-    kv_len: int = key.size()[-2]
-    q_len: int = query.size()[-2]
+    kv_len = round_up_to_multiple(key.size()[-2], 128)
+    q_len = round_up_to_multiple(query.size()[-2], 128)
     return BlockMask(
         kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
         kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
@@ -708,6 +713,8 @@ def flex_attention(
     """
     # Some basic input validation
     _validate_sdpa_input(query, key, value)
+    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
+        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
     if query.size(-2) >= 32:  # use Attention Kernel
         if query.size(-2) >= 128 and query.size(-2) % 128 != 0:
             raise NotImplementedError("NYI: S must be <128 or a multiple of 128")
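With this check, callers must pass 4D `[batch, heads, seq_len, head_dim]` tensors, exactly as the updated tests at the top of this diff now do. A minimal sketch of a valid call (illustrative shapes; the import path assumes a recent PyTorch where `flex_attention` is exposed under `torch.nn.attention.flex_attention`, and the no-op `score_mod` mirrors the tests' `_identity` helper):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # path assumed, see note above

def identity_score_mod(score, b, h, q_idx, kv_idx):
    return score  # no-op score_mod, like the tests' _identity

q = torch.randn(1, 2, 128, 64)  # [batch, heads, q_len, head_dim]; q_len multiple of 128
k = torch.randn(1, 2, 128, 64)
v = torch.randn(1, 2, 128, 64)
out = flex_attention(q, k, v, score_mod=identity_score_mod)
print(out.shape)  # torch.Size([1, 2, 128, 64])
```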