Add Full block support to flex_decoding (#131404)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131404
Approved by: https://github.com/yanboliang
commit bdd83c4c7f (parent 043e41f4f4)
Author: joydddd
Date: 2024-07-31 20:04:45 -07:00
Committed by: PyTorch MergeBot
5 changed files with 261 additions and 176 deletions


@@ -37,6 +37,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 class ExperimentConfig:
     shape: Tuple[int]
     score_mod: Callable
+    mask_mod: Callable
     dtype: torch.dtype
     calculate_bwd_time: bool
     cal_bandwidth: bool
@@ -122,7 +123,9 @@ def generate_inputs(
 def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
+    config: ExperimentConfig,
+    dynamic=False,
+    max_autotune=False,
 ) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -149,13 +152,14 @@ def run_single_experiment(
     compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)
     score_mod = config.score_mod
+    mask_mod = config.mask_mod
 
-    if enable_mask:
+    if mask_mod:
         block_mask = create_block_mask(
-            score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
+            mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
         )
     else:
-        block_mask = _create_empty_block_mask(query, key, value)
+        block_mask = _create_empty_block_mask(query, key)
 
     forward_eager_time = benchmark_torch_function_in_microseconds(
         eager_sdpa, query, key, value, score_mod
@@ -328,7 +332,7 @@ def print_results(results: List[Experiment]):
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
 
-def generate_score_mods(score_mods: List[str]) -> List[Callable]:
+def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
     def noop(score, b, h, m, n):
         return score
@@ -343,14 +347,33 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
     function_dict = {
         "noop": noop,
-        "causal": causal_mask,
+        "causal": None,
         "rel": relative_bias,
         "head_bias": head_bias,
     }
     return [function_dict[name] for name in score_mods]
 
 
+def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
+    def noop(b, h, m, n):
+        return True
+
+    def causal(b, h, m, n):
+        return m >= n
+
+    mask_mod_dict = {
+        "noop": None,
+        "causal": causal,
+        "rel": None,
+        "head_bias": None,
+    }
+    return [mask_mod_dict[name] for name in score_mods]
+
+
 def get_gqa_score_mod(score_mod, G, q_seq_len):
+    if score_mod is None:
+        return None
+
     def score_mod_gqa(score, b, hkv, m, n):
         g = m // q_seq_len
         new_m = m % q_seq_len
@@ -362,6 +385,21 @@ def get_gqa_score_mod(score_mod, G, q_seq_len):
     return score_mod_gqa
 
 
+def get_gqa_mask_mod(mask_mod, G, q_seq_len):
+    if mask_mod is None:
+        return None
+
+    def mask_mod_gqa(b, h, m, n):
+        g = m // q_seq_len
+        new_m = m % q_seq_len
+        hq = h * G + g
+        return mask_mod(b, hq, new_m, n)
+
+    mask_mod_name = get_func_name(mask_mod)
+    set_func_name(mask_mod_gqa, mask_mod_name + "_gqa")
+    return mask_mod_gqa
+
+
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
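
The GQA remap above folds the group index into the query-row coordinate: the rows for all G query heads that share one KV head are stacked along M, and the wrapper recovers (hq, new_m) from (hkv, m). A self-contained check of that arithmetic (plain Python, names local to this sketch):

    # Standalone sketch of the index remap performed by get_gqa_mask_mod.
    def causal(b, h, m, n):
        return m >= n

    def gqa_mask_mod(b, hkv, m, n, G=4, q_seq_len=8):
        g = m // q_seq_len     # which query-head group this row belongs to
        new_m = m % q_seq_len  # row index within the original query sequence
        hq = hkv * G + g       # original query head index
        return causal(b, hq, new_m, n)

    # Row m=13 with G=4, q_seq_len=8 decodes to group 1, original row 5:
    assert gqa_mask_mod(b=0, hkv=2, m=13, n=5)      # 5 >= 5 -> True
    assert not gqa_mask_mod(b=0, hkv=2, m=13, n=6)  # 5 <  6 -> False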
@@ -369,7 +407,7 @@ def generate_experiment_configs(
     num_heads: List[Tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
-    score_mods: List[str],
+    score_mods_str: List[str],
     decoding: bool,
     kv_cache_size: List[int],
     cal_bandwidth: bool,
@@ -381,7 +419,8 @@ def generate_experiment_configs(
     else:
         q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
     dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+    score_mods = generate_score_mods(score_mods_str)
+    mask_mods = generate_mask_mods(score_mods_str)
     all_configs = []
     for (
         bsz,
@@ -389,6 +428,7 @@ def generate_experiment_configs(
         (q_seq_len, kv_seq_len),
         head_dim,
         score_mod,
+        mask_mod,
         dtype,
     ) in itertools.product(
         kv_cache_size if kv_cache_size else batch_sizes,
@@ -396,6 +436,7 @@ def generate_experiment_configs(
         q_kv_seq_lens,
         head_dims,
         score_mods,
+        mask_mods,
         dtypes,
     ):
         if kv_cache_size:
@@ -410,11 +451,13 @@ def generate_experiment_configs(
             assert q_heads % kv_heads == 0
             G = q_heads // kv_heads
             score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
+            mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)
 
         all_configs.append(
             ExperimentConfig(
                 shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
                 score_mod=score_mod,
+                mask_mod=mask_mod,
                 dtype=dtype,
                 calculate_bwd_time=calculate_bwd,
                 cal_bandwidth=cal_bandwidth,
@@ -450,7 +493,6 @@ def main(args):
                     config,
                     dynamic=args.dynamic,
                     max_autotune=args.max_autotune,
-                    enable_mask=args.mask,
                 ),
             )
         )
@@ -526,9 +568,6 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
         action="store_true",
         help="Calculate kernel memory bandwidth & computational throughput. ",
     )
-    parser.add_argument(
-        "--mask", action="store_true", help="Enables block sparsity mask. "
-    )
 
     # Parse arguments
     args = parser.parse_args()
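
End to end, the benchmark now exercises the same pattern users write: build a block mask from a boolean mask_mod and pass it alongside an optional score_mod, instead of toggling sparsity with a --mask flag. A minimal sketch of that call pattern (assuming the torch.nn.attention.flex_attention import path of the released API; shapes are illustrative and a CUDA build is assumed):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal(b, h, m, n):
        return m >= n  # mask_mod: returns bool, True = keep this (q, kv) pair

    B, H, Q_LEN, KV_LEN, D = 2, 8, 1024, 1024, 64
    q = torch.randn(B, H, Q_LEN, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)

    # The block mask records, per (q-block, kv-block) tile, whether the tile is
    # empty, partial (mask_mod must run), or full (mask_mod can be skipped).
    block_mask = create_block_mask(causal, 1, 1, Q_LEN, KV_LEN, device="cuda")
    out = flex_attention(q, k, v, block_mask=block_mask)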


@@ -91,6 +91,19 @@ def _causal(
     return torch.where(token_q >= token_kv, score, float("-inf"))
 
 
+def _generate_windowed(offset):
+    def _windowed(score, b, h, q, kv):
+        return torch.where(q + offset >= kv, score, float("-inf"))
+
+    return _windowed
+
+
+def _get_windowed_sdpa_mask(Mq, Mkv, offset):
+    return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
+        offset : offset + Mq
+    ]
+
+
 def _rel_bias(
     score: Tensor,
     batch: Tensor,
@@ -171,6 +184,7 @@ test_score_mods = [
     _rel_bias,
     _rel_causal,
     _generate_alibi_bias(8),
+    _generate_windowed(1000),
 ]
 
 captured_buffers_map = {
@@ -825,42 +839,48 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(bias_mod)
 
     @supported_platform
-    def test_causal_no_mask_vs_sdpa(self):
-        attention = functools.partial(flex_attention, score_mod=_causal)
+    def test_windowed_no_mask_vs_sdpa(self):
+        score_mod = _generate_windowed(1000)
+        attention = functools.partial(flex_attention, score_mod=score_mod)
 
+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
         )
 
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
 
     @supported_platform
-    def test_causal_full_mask_vs_sdpa(self):
+    def test_windowed_full_mask_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
-            return q >= kv
+            return q + 1000 >= kv
 
+        score_mod = _generate_windowed(1000)
         block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
         attention = functools.partial(
-            flex_attention, block_mask=block_mask, score_mod=_causal
+            flex_attention, block_mask=block_mask, score_mod=score_mod
        )
 
+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
         )
 
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
 
+    @expectedFailure  # TODO: add support for partial mask.
     @supported_platform
-    def test_causal_partial_block_vs_sdpa(self):
+    def test_windowed_partial_block_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
-            return q >= kv
+            return q + 1000 >= kv
 
         block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
         attention = functools.partial(flex_attention, block_mask=block_mask)
 
+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
        )
 
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
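
All three windowed tests compare against the same SDPA reference. That reference works because a sliding-window predicate is just a shifted causal mask, so _get_windowed_sdpa_mask can slice rows out of a lower-triangular matrix. A quick CPU check of the equivalence (sizes are illustrative):

    import torch

    Mq, Mkv, offset = 8, 2048, 1000

    # Rows [offset, offset + Mq) of a (Mkv, Mkv) lower-triangular boolean
    # matrix, mirroring _get_windowed_sdpa_mask (built on CPU here).
    sdpa_mask = torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool))[offset : offset + Mq]

    # The same predicate written the mask_mod way: keep (q, kv) iff q + offset >= kv.
    q_idx = torch.arange(Mq)[:, None]
    kv_idx = torch.arange(Mkv)[None, :]
    assert torch.equal(sdpa_mask, q_idx + offset >= kv_idx)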


@@ -113,6 +113,7 @@ def build_subgraph_buffer(
     raise ValueError("FlexAttention was passed a subgraph with no output node!")
 
 
+# Inner Triton functions shared by flex_attention & split-k decoding kernels.
 compute_next_offset_func = r"""
 @triton.jit
 def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
@@ -126,10 +127,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK
     return offset
 """
 
-flex_attention_template = TritonTemplate(
-    name="flex_attention",
-    grid=flex_attention_grid,
-    source=r"""
+compute_flex_attention = r"""
 {{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
     # Sub notation for this kernel:
     #
@@ -248,6 +246,7 @@ flex_attention_template = TritonTemplate(
             acc, l_i, m_i,
             off_z, off_h, offs_m, offs_n,
             kv_indices, kv_num_blocks,
+            0, kv_num_blocks * SPARSE_KV_MULTIPLE,
             MATMUL_PRECISION,
             {{gen_argdefs()}},
             IS_FULL_BLOCKS=False
@@ -285,6 +284,7 @@ flex_attention_template = TritonTemplate(
             acc, l_i, m_i,
             off_z, off_h, offs_m, offs_n,
             kv_indices, kv_num_blocks,
+            0, kv_num_blocks * SPARSE_KV_MULTIPLE,
             MATMUL_PRECISION,
             {{gen_argdefs()}},
             IS_FULL_BLOCKS=True
@@ -311,8 +311,10 @@ flex_attention_template = TritonTemplate(
     l_ptrs = LSE + off_hz * Q_LEN + offs_m
     lse = m_i + tl.math.log2(l_i)
     tl.store(l_ptrs, lse)
+"""
 
+compute_forward_block = r"""
 @triton.jit
 def forward_inner(
     q, K_block_ptr, V_block_ptr,
@@ -322,6 +324,8 @@ def forward_inner(
     off_z, off_h, offs_m, offs_n,
     # blocksparse data
     kv_indices, kv_num_blocks,
+    # start kv and end kv block
+    block_n_start, block_n_end,
     MATMUL_PRECISION,
     {{gen_argdefs()}},
     IS_FULL_BLOCKS,
@@ -339,19 +343,17 @@ def forward_inner(
         q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
 
     # loop over k, v and update accumulator
-    lo = 0
-    hi = kv_num_blocks * SPARSE_KV_MULTIPLE
-    for start_n in range(0, hi):
+    for start_n in range(block_n_start, block_n_end):
         # -- load k --
         k = tl.load(K_block_ptr)
         # -- compute qk ---
-        qk = tl.dot(q, k)
+        qk = tl.dot(q, k)  # TODO: use cuda matmul when q_len <= 2.
         if not PRESCALE_QK:
             qk *= SM_SCALE
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
         m = offs_m[:, None]
         n = offs_n[None, :]
         # TODO: Add load mask in modification when M/N Boundary is not safe
         {{ modification(
             subgraph_number=0,
             output_name="post_mod_scores",
@@ -413,8 +415,14 @@ def forward_inner(
         offs_n = offs_n + offset
 
     return acc, l_i, m_i
-"""
-    + compute_next_offset_func,
+"""
+
+
+flex_attention_template = TritonTemplate(
+    name="flex_attention",
+    grid=flex_attention_grid,
+    source=compute_flex_attention + compute_forward_block + compute_next_offset_func,
 )
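
The new block_n_start/block_n_end parameters are what let both kernels share forward_inner: flex_attention still walks the entire KV block list, while flex_decoding (next file) hands each CTA one clamped tile of that list. The two call sites, restated as plain Python (a sketch of the index arithmetic only):

    # flex_attention: one program covers the whole block list, as before.
    def attention_range(kv_num_blocks, SPARSE_KV_MULTIPLE):
        return 0, kv_num_blocks * SPARSE_KV_MULTIPLE

    # flex_decoding: CTA off_t takes TILE_KV_MULTIPLE blocks, clamped to the
    # last block that the sparse mask says is valid.
    def decoding_range(off_t, TILE_KV_MULTIPLE, kv_num_blocks, SPARSE_KV_MULTIPLE):
        block_n_start = off_t * TILE_KV_MULTIPLE
        block_n_end = block_n_start + TILE_KV_MULTIPLE
        block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE
        return block_n_start, min(block_n_end, block_n_last_valid)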
@@ -570,17 +578,6 @@ def flex_attention(
     subgraph_buffer = build_subgraph_buffer(
         placeholder_inps + list(score_mod_other_buffers), subgraph
     )
-    if _use_flex_decoding(query):
-        return create_flex_decoding_kernel(
-            subgraph_buffer,
-            query,
-            key,
-            value,
-            subgraph,
-            block_mask,
-            scale,
-            *score_mod_other_buffers,
-        )
     mask_graph_placeholder_inps = [
         create_placeholder(name, dtype, query.get_device())
         for name, dtype in [
@@ -593,6 +590,18 @@ def flex_attention(
     mask_graph_buffer = build_subgraph_buffer(
         mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
     )
+    if _use_flex_decoding(query):
+        return create_flex_decoding_kernel(
+            query,
+            key,
+            value,
+            block_mask,
+            scale,
+            subgraph_buffer,
+            mask_graph_buffer,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
     for buf in [
         query,
         key,
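
Two things move together here: the decoding early-return now happens after the mask graph is lowered, because create_flex_decoding_kernel consumes both the score_mod and mask_mod subgraphs, and its argument order changes to pass them explicitly. _use_flex_decoding itself is not part of this diff; purely as an assumption for illustration, it gates on the short-query (decode) regime, consistent with the Q < 128 boundary used on the eager side below:

    # Hypothetical sketch of the dispatch predicate (the real _use_flex_decoding
    # is defined elsewhere in flex_attention.py and is not shown in this diff).
    def _use_flex_decoding(query) -> bool:
        # Route to the split-KV decoding kernel when there are few query tokens
        # per head, mirroring the Q < 128 decoding threshold in the frontend.
        return query.get_size()[-2] < 128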


@@ -8,9 +8,10 @@ import torch
 from torch._inductor.virtualized import V
 
 from ..ir import FixedLayout, FlexibleLayout
-from ..lowering import empty_strided, lowerings
+from ..lowering import empty, empty_strided, lowerings
 from ..runtime.runtime_utils import next_power_of_2
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
+from .flex_attention import compute_forward_block, compute_next_offset_func
 
 aten = torch.ops.aten
@@ -32,7 +33,7 @@ flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
     source=r"""
-    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX")}}
+    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
     # Sub notation for this kernel:
     # Q: Query, K: Key, V: Value
     # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
@@ -114,21 +115,12 @@ flex_decoding_template = TritonTemplate(
     sparse_idx_z = off_z % SPARSE_Z
     sparse_idx_h = off_h % SPARSE_H
 
     # TODO: strided KV_IDX and KV_NUM_BLKS
     sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
-    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
-    sparse_block_num = tl.load(KV_NUM_BLKS + sparse_hz_offset)
 
-    block_n_start = off_t * TILE_KV_MULTIPLE  # n_offset inside sparse block
-    block_n_end = block_n_start + TILE_KV_MULTIPLE  # end BLOCK_N
-    block_n_last_valid = sparse_block_num * SPARSE_KV_MULTIPLE  # last valid block according to sparse mask
-    block_n_end = block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid
 
-    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
-    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
-    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
-    # first kv block we're loading
+    # Calculate KV blocks that belong to this CTA.
+    block_n_start = off_t * TILE_KV_MULTIPLE  # n_offset inside sparse block
+    block_n_end = block_n_start + TILE_KV_MULTIPLE  # end BLOCK_N
 
     q_offset = off_z * stride_qz + off_h * stride_qh
     k_offset = off_z * stride_kz + off_h * stride_kh
@@ -141,6 +133,32 @@ flex_decoding_template = TritonTemplate(
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
+    if SAFE_M_BOUNDARY:
+        q = tl.load(Q_block_ptr)
+    else:
+        q = tl.load(Q_block_ptr, boundary_check=(0, ))
+
+    # initialize offsets
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # Apply both score_mod and mask_mod
+
+    # find first kv block we are loading and the number of blocks we are loading
+    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
+    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
+    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
+    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
+    # first kv block we're loading
+    block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE  # last valid block according to sparse mask
 
     K_block_ptr = tl.make_block_ptr(
         base=K + k_offset,
@@ -158,6 +176,66 @@ flex_decoding_template = TritonTemplate(
         block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
+    offs_n = tl.arange(0, BLOCK_N) + off_n
+
+    acc, l_i, m_i = forward_inner(
+        q, K_block_ptr, V_block_ptr,
+        # accumulated values
+        acc, l_i, m_i,
+        # offsets
+        off_z, off_h, offs_m, offs_n,
+        # block sparse data
+        kv_indices, kv_num_blocks,
+        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
+        MATMUL_PRECISION,
+        {{gen_argdefs()}},
+        IS_FULL_BLOCKS=False,
+    )
+
+    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # We know these blocks are guaranteed to be "full", so we don't need to
+    # apply mask_mod to them - only score_mod
+    if HAS_FULL_BLOCKS:
+        kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
+        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
+        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
+        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
+        # first kv block we're loading
+        block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE  # last valid block according to sparse mask
+
+        K_block_ptr = tl.make_block_ptr(
+            base=K + k_offset,
+            shape=(BLOCK_DMODEL, KV_LEN),  # (d, N)
+            strides=(stride_kk, stride_kn),
+            offsets=(0, off_n),
+            block_shape=(BLOCK_DMODEL, BLOCK_N),
+            order=(0, 1)
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + v_offset,
+            shape=(KV_LEN, BLOCK_DMODEL),
+            strides=(stride_vk, stride_vn),
+            offsets=(off_n, 0),
+            block_shape=(BLOCK_N, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+        offs_n = tl.arange(0, BLOCK_N) + off_n
+
+        acc, l_i, m_i = forward_inner(
+            q, K_block_ptr, V_block_ptr,
+            # accumulated values
+            acc, l_i, m_i,
+            # offsets
+            off_z, off_h, offs_m, offs_n,
+            # block sparse data
+            kv_indices, kv_num_blocks,
+            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
+            MATMUL_PRECISION,
+            {{gen_argdefs()}},
+            IS_FULL_BLOCKS=True,
+        )
 
     m_offset = off_h * stride_mh + off_z * stride_mz
     l_offset = off_h * stride_lh + off_z * stride_lz
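
This gives the decoding kernel the same two-phase structure as flex_attention: partial ("normal") blocks run both mask_mod and score_mod, while provably full blocks skip mask_mod entirely, which is the optimization this PR extends to decoding. In reference semantics (a plain-PyTorch sketch with simplified block-level mod signatures, not the kernel itself):

    import torch

    # mask_mod/score_mod here act on whole column blocks rather than scalar
    # indices; q is (M, d), k is (d, N).
    def two_phase_scores(q, k, kv_idx_partial, kv_idx_full, score_mod, mask_mod, BLOCK_N):
        blocks = []
        for i in kv_idx_partial:              # ~~ normal blocks ~~
            n = torch.arange(i * BLOCK_N, (i + 1) * BLOCK_N)
            s = score_mod(q @ k[:, n], n)     # score_mod always runs
            blocks.append(s.masked_fill(~mask_mod(n), float("-inf")))
        for i in kv_idx_full:                 # ~~ "full" blocks ~~
            n = torch.arange(i * BLOCK_N, (i + 1) * BLOCK_N)
            blocks.append(score_mod(q @ k[:, n], n))  # mask_mod skipped: all True
        return torch.cat(blocks, dim=-1)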
@@ -178,92 +256,6 @@ flex_decoding_template = TritonTemplate(
         order=(1, 0)
     )
 
-    # initialize offsets
-    offs_m = tl.arange(0, BLOCK_M)
-    offs_n = tl.arange(0, BLOCK_N) + off_n
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-
-    # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
-    if SAFE_M_BOUNDARY:
-        q = tl.load(Q_block_ptr)
-    else:
-        q = tl.load(Q_block_ptr, boundary_check=(0, ))
-
-    RCP_LN2 = 1.44269504
-    if PRESCALE_QK:
-        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
-
-    # loop over k, v and update accumulator
-    for start_n in range(block_n_start, block_n_end):
-        # -- load k, v --
-        k = tl.load(K_block_ptr).to(MATMUL_PRECISION)
-        v = tl.load(V_block_ptr)
-
-        # -- compute qk ---
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk = tl.dot(q, k, acc=qk)
-        if not PRESCALE_QK:
-            qk *= SM_SCALE
-        # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
-        m = offs_m[:, None]
-        n = offs_n[None, :]
-        # TODO: Add load mask in modification when M/N Boundary is not safe
-        {{ modification(
-            subgraph_number=0,
-            output_name="post_mod_scores",
-            score="qk",
-            b="off_z",
-            h="off_h",
-            m="m",
-            n="n",
-            out="qk"
-        ) | indent_except_first(2) }}
-        # TODO: In the case that score_mod is linear, this can be LICMed
-        if not PRESCALE_QK:
-            post_mod_scores *= RCP_LN2
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # -- compute scaling constant ---
-        row_max = tl.max(post_mod_scores, 1)
-        m_i_new = tl.maximum(m_i, row_max)
-        alpha = tl.math.exp2(m_i - m_i_new)
-        p = tl.math.exp2(post_mod_scores - m_i_new[:, None])
-        if not ROWS_GUARANTEED_SAFE:
-            masked_out_rows = (m_i_new == float("-inf"))
-            alpha = tl.where(masked_out_rows, 0, alpha)
-            p = tl.where(masked_out_rows[:, None], 0, p)
-
-        # -- scale and update acc --
-        acc *= alpha[:, None]
-        acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc=acc)
-
-        # -- update m_i and l_i --
-        l_i = l_i * alpha + tl.sum(p, 1)
-        m_i = m_i_new
-
-        # update pointers
-        indices_idx = start_n // SPARSE_KV_MULTIPLE
-        cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
-        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_block_num)
-        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
-        jump_to_block = (next_block - cur_block) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
-        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N
-
-        K_block_ptr = tl.advance(K_block_ptr, (0, offset))
-        V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
-        offs_n = offs_n + offset
 
     # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
     if SAFE_M_BOUNDARY:
         tl.store(M_block_ptr, m_i[None, :])
@@ -281,7 +273,9 @@ flex_decoding_template = TritonTemplate(
     # TODO generalize and add proper mask support
     mask = (idx_m < Q_LEN)
     {{store_output(("idx_z", "idx_h", "idx_t", "idx_m", "idx_d"), "acc", "mask")}}
-    """,
+    """
+    + compute_forward_block
+    + compute_next_offset_func,
 )
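
Why the kernel stores m_i and l_i at all: flex decoding is a split-K scheme, so each CTA reduces only its slice of KV and the partial results are merged across CTAs afterwards. That merge is the standard online-softmax combine (the kernel works in base 2 via exp2/log2; the sketch below uses natural exp):

    import torch

    # Each KV split produces a partial (rowmax m, sumexp l, weighted acc);
    # the buf_M / buf_L buffers exist so this exact combine can run later.
    def merge_splits(m1, l1, acc1, m2, l2, acc2):
        m = torch.maximum(m1, m2)                      # new running rowmax
        a1, a2 = torch.exp(m1 - m), torch.exp(m2 - m)  # rescale factors
        l = l1 * a1 + l2 * a2                          # combined sumexp
        acc = acc1 * a1[:, None] + acc2 * a2[:, None]  # combined weighted values
        return m, l, acc

    # Sanity check against a single-pass softmax over the concatenated scores:
    s1, s2 = torch.randn(4, 16), torch.randn(4, 16)
    v1, v2 = torch.randn(16, 8), torch.randn(16, 8)

    def partial(s, v):
        m = s.max(-1).values
        p = torch.exp(s - m[:, None])
        return m, p.sum(-1), p @ v

    m, l, acc = merge_splits(*partial(s1, v1), *partial(s2, v2))
    ref = torch.softmax(torch.cat([s1, s2], -1), -1) @ torch.cat([v1, v2], 0)
    assert torch.allclose(acc / l[:, None], ref, atol=1e-5)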
@@ -305,36 +299,47 @@ def _get_decoding_default_config(key) -> Tuple[int, int, int]:
 def create_flex_decoding_kernel(*args, **kwargs):
     (
-        subgraph_buffer,
         query,
         key,
         value,
-        subgraph,
         block_mask,
         scale,
-        *other_buffers,
+        score_mod_subgraph,
+        mask_mod_subgraph,
+        score_mod_other_buffers,
+        mask_mod_other_buffers,
     ) = args
     (
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
+        kv_num_blocks,
+        kv_indices,
         full_kv_num_blocks,  # full_kv_num_blocks,
-        _,  # full_kv_indices,
+        full_kv_indices,  # full_kv_indices,
         _,  # q_num_blocks
         _,  # q_indices
         _,  # full_q_num_blocks,
         _,  # full_q_indices,
         SPARSE_KV_BLOCK_SIZE,
         _,  # SPARSE_Q_BLOCK_SIZE,
         mask_graph,
         _,
     ) = block_mask
 
-    if full_kv_num_blocks is not None:
-        raise NotImplementedError("NYI: Flex decoding only supports full mask. ")
+    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
+    has_full_blocks = full_kv_num_blocks is not None
+    if (
+        full_kv_num_blocks is None
+    ):  # Create a placeholder full block list in case it is empty
+        full_kv_num_blocks, full_kv_indices = (
+            empty(0, device=query.get_device()) for _ in range(2)
+        )
 
     for buf in [
         query,
         key,
         value,
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
+        kv_num_blocks,
+        kv_indices,
+        full_kv_num_blocks,
+        full_kv_indices,
     ]:
         buf.realize()
 
     choices: List[Any] = []
@@ -412,12 +417,15 @@ def create_flex_decoding_kernel(*args, **kwargs):
             value,
             buf_M,
             buf_L,
-            sparse_kv_num_blocks,
-            sparse_kv_indices,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
         ],
         layout=layout_acc,
         subgraphs=[
-            subgraph_buffer,
+            score_mod_subgraph,
+            mask_mod_subgraph,
         ],
         mutated_inputs=[buf_M, buf_L],
         num_stages=num_stages,
@@ -429,23 +437,32 @@ def create_flex_decoding_kernel(*args, **kwargs):
         SM_SCALE=scale,
         # Performance tuning
         BLOCK_N=BLOCK_N,
+        # Sparse block size
+        SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
         # For now, we always assume the "sound" option
         ROWS_GUARANTEED_SAFE=False,
         SAFE_M_BOUNDARY=(query.get_size()[-2] % BLOCK_M) == 0,
         SAFE_N_BOUNDARY=True,
         PRESCALE_QK=False,
-        SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
+        HAS_FULL_BLOCKS=has_full_blocks,
     )
 
-    inputs_for_flex_decoding = [
-        query,
-        key,
-        value,
-        buf_M,
-        buf_L,
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
-    ] + list(other_buffers)
+    inputs_for_flex_decoding = (
+        [
+            query,
+            key,
+            value,
+            buf_M,
+            buf_L,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+        ]
+        + list(score_mod_other_buffers)
+        + list(mask_mod_other_buffers)
+    )
 
     buf_ACC = autotune_select_algorithm(
         "flex_decoding", choices, inputs_for_flex_decoding, layout_acc
     )


@@ -511,9 +511,6 @@ def _convert_mask_to_block_mask(
     assert mask.dtype == torch.bool
     mask = _broadcast_to_dim(mask, 4)
     B, H, Q, KV = mask.shape
-    is_decoding = Q < 128
-    if is_decoding:
-        Q_BLOCK_SIZE = Q
     assert Q % Q_BLOCK_SIZE == 0
     assert KV % KV_BLOCK_SIZE == 0
     mask = mask.view(
@@ -525,7 +522,7 @@ def _convert_mask_to_block_mask(
     mask_block_sum = mask.sum(
         dim=[-2, -1]
     )  # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE]
-    if separate_full_blocks and not is_decoding:
+    if separate_full_blocks:
         full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE
         full_blocks = mask_block_sum == full_block_sum
         partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum)
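
With the is_decoding special case gone, _convert_mask_to_block_mask always separates full blocks from partial ones; the classification is a per-tile popcount. A 2-D simplification of that computation (the real code keeps the B and H dims):

    import torch

    Q_BLOCK_SIZE = KV_BLOCK_SIZE = 4
    Q = KV = 8
    mask = torch.arange(Q)[:, None] >= torch.arange(KV)[None, :]  # causal, bool

    tiles = mask.view(Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE)
    block_sum = tiles.permute(0, 2, 1, 3).sum(dim=[-2, -1])       # popcount per tile

    full = Q_BLOCK_SIZE * KV_BLOCK_SIZE
    full_blocks = block_sum == full                               # mask_mod skippable
    partial_blocks = (block_sum > 0) & (block_sum < full)         # mask_mod required

    # 8x8 causal mask, 4x4 tiles: lower-left tile is full, diagonal tiles are
    # partial, upper-right tile is empty.
    assert full_blocks.tolist() == [[False, False], [True, False]]
    assert partial_blocks.tolist() == [[True, False], [False, True]]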
@@ -732,7 +729,10 @@
         mod_type == _ModificationType.MASK_MOD
     ), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
     inner_func = _create_block_mask_inner
-    Q_LEN = Q_LEN if Q_LEN < 128 else _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
+    if Q_LEN < 128:
+        Q_BLOCK_SIZE = Q_LEN
+    else:
+        Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
     KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
     if _compile:
         inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
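
On the create_block_mask side, a short query now shrinks Q_BLOCK_SIZE to the query length instead of opting out of full-block separation, so decode shapes with Q_LEN < 128 get a one-row block mask that still carries real full-block information. Illustratively (inlining the rounding helper; shapes are a typical decode step):

    # Q_LEN < 128: the whole query fits in one block row, so Q_BLOCK_SIZE is
    # clamped to Q_LEN and only KV_LEN is rounded up to a block multiple.
    Q_LEN, KV_LEN = 1, 8192
    Q_BLOCK_SIZE, KV_BLOCK_SIZE = 128, 128

    if Q_LEN < 128:
        Q_BLOCK_SIZE = Q_LEN  # -> 1: a single block row of shape (1, KV blocks)
    else:
        Q_LEN = (Q_LEN + Q_BLOCK_SIZE - 1) // Q_BLOCK_SIZE * Q_BLOCK_SIZE

    KV_LEN = (KV_LEN + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE * KV_BLOCK_SIZE
    assert (Q_BLOCK_SIZE, KV_LEN) == (1, 8192)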