mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add Full block support to flex_decoding (#131404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131404 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
043e41f4f4
commit
bdd83c4c7f
5 changed files with 261 additions and 176 deletions
|
|
@ -37,6 +37,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
|
|||
class ExperimentConfig:
|
||||
shape: Tuple[int]
|
||||
score_mod: Callable
|
||||
mask_mod: Callable
|
||||
dtype: torch.dtype
|
||||
calculate_bwd_time: bool
|
||||
cal_bandwidth: bool
|
||||
|
|
@ -122,7 +123,9 @@ def generate_inputs(
|
|||
|
||||
|
||||
def run_single_experiment(
|
||||
config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
|
||||
config: ExperimentConfig,
|
||||
dynamic=False,
|
||||
max_autotune=False,
|
||||
) -> ExperimentResults:
|
||||
device = torch.device("cuda")
|
||||
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
|
||||
|
|
@ -149,13 +152,14 @@ def run_single_experiment(
|
|||
compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)
|
||||
|
||||
score_mod = config.score_mod
|
||||
mask_mod = config.mask_mod
|
||||
|
||||
if enable_mask:
|
||||
if mask_mod:
|
||||
block_mask = create_block_mask(
|
||||
score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
|
||||
mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
|
||||
)
|
||||
else:
|
||||
block_mask = _create_empty_block_mask(query, key, value)
|
||||
block_mask = _create_empty_block_mask(query, key)
|
||||
|
||||
forward_eager_time = benchmark_torch_function_in_microseconds(
|
||||
eager_sdpa, query, key, value, score_mod
|
||||
|
|
@ -328,7 +332,7 @@ def print_results(results: List[Experiment]):
|
|||
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
|
||||
def generate_score_mods(score_mods: List[str]) -> List[Callable]:
|
||||
def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
|
||||
def noop(score, b, h, m, n):
|
||||
return score
|
||||
|
||||
|
|
@ -343,14 +347,33 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
|
|||
|
||||
function_dict = {
|
||||
"noop": noop,
|
||||
"causal": causal_mask,
|
||||
"causal": None,
|
||||
"rel": relative_bias,
|
||||
"head_bias": head_bias,
|
||||
}
|
||||
return [function_dict[name] for name in score_mods]
|
||||
|
||||
|
||||
def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
|
||||
def noop(b, h, m, n):
|
||||
return True
|
||||
|
||||
def causal(b, h, m, n):
|
||||
return m >= n
|
||||
|
||||
mask_mod_dict = {
|
||||
"noop": None,
|
||||
"causal": causal,
|
||||
"rel": None,
|
||||
"head_bias": None,
|
||||
}
|
||||
return [mask_mod_dict[name] for name in score_mods]
|
||||
|
||||
|
||||
def get_gqa_score_mod(score_mod, G, q_seq_len):
|
||||
if score_mod is None:
|
||||
return None
|
||||
|
||||
def score_mod_gqa(score, b, hkv, m, n):
|
||||
g = m // q_seq_len
|
||||
new_m = m % q_seq_len
|
||||
|
|
@ -362,6 +385,21 @@ def get_gqa_score_mod(score_mod, G, q_seq_len):
|
|||
return score_mod_gqa
|
||||
|
||||
|
||||
def get_gqa_mask_mod(mask_mod, G, q_seq_len):
    """Adapt ``mask_mod`` to an index layout where ``m`` folds group and query position.

    Args:
        mask_mod: A ``mask_mod(b, h, m, n)`` callable, or ``None``.
        G: Multiplier applied to the incoming head index ``h`` when rebuilding
            the head index passed to ``mask_mod``.
        q_seq_len: Length used to split the incoming ``m`` into a group index
            (``m // q_seq_len``) and a position (``m % q_seq_len``).

    Returns:
        ``None`` when ``mask_mod`` is ``None``; otherwise a wrapper that calls
        ``mask_mod(b, h * G + m // q_seq_len, m % q_seq_len, n)`` and whose
        function name is the original name with ``"_gqa"`` appended.
    """
    if mask_mod is None:
        return None

    def mask_mod_gqa(b, h, m, n):
        # Split the folded row index into (group, position-within-sequence).
        group_idx = m // q_seq_len
        q_idx = m % q_seq_len
        # Rebuild the head index from the incoming head and the group.
        return mask_mod(b, h * G + group_idx, q_idx, n)

    # Name the wrapper after the wrapped function, with a "_gqa" suffix.
    set_func_name(mask_mod_gqa, get_func_name(mask_mod) + "_gqa")
    return mask_mod_gqa
|
||||
|
||||
|
||||
def generate_experiment_configs(
|
||||
calculate_bwd: bool,
|
||||
dtype: torch.dtype,
|
||||
|
|
@ -369,7 +407,7 @@ def generate_experiment_configs(
|
|||
num_heads: List[Tuple[int, int]],
|
||||
seq_lens: List[int],
|
||||
head_dims: List[int],
|
||||
score_mods: List[str],
|
||||
score_mods_str: List[str],
|
||||
decoding: bool,
|
||||
kv_cache_size: List[int],
|
||||
cal_bandwidth: bool,
|
||||
|
|
@ -381,7 +419,8 @@ def generate_experiment_configs(
|
|||
else:
|
||||
q_kv_seq_lens = [(i, i) for i in seq_lens] # only testing q_len == kv_len
|
||||
dtypes = [dtype]
|
||||
score_mods = generate_score_mods(score_mods)
|
||||
score_mods = generate_score_mods(score_mods_str)
|
||||
mask_mods = generate_mask_mods(score_mods_str)
|
||||
all_configs = []
|
||||
for (
|
||||
bsz,
|
||||
|
|
@ -389,6 +428,7 @@ def generate_experiment_configs(
|
|||
(q_seq_len, kv_seq_len),
|
||||
head_dim,
|
||||
score_mod,
|
||||
mask_mod,
|
||||
dtype,
|
||||
) in itertools.product(
|
||||
kv_cache_size if kv_cache_size else batch_sizes,
|
||||
|
|
@ -396,6 +436,7 @@ def generate_experiment_configs(
|
|||
q_kv_seq_lens,
|
||||
head_dims,
|
||||
score_mods,
|
||||
mask_mods,
|
||||
dtypes,
|
||||
):
|
||||
if kv_cache_size:
|
||||
|
|
@ -410,11 +451,13 @@ def generate_experiment_configs(
|
|||
assert q_heads % kv_heads == 0
|
||||
G = q_heads // kv_heads
|
||||
score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
|
||||
mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)
|
||||
|
||||
all_configs.append(
|
||||
ExperimentConfig(
|
||||
shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
|
||||
score_mod=score_mod,
|
||||
mask_mod=mask_mod,
|
||||
dtype=dtype,
|
||||
calculate_bwd_time=calculate_bwd,
|
||||
cal_bandwidth=cal_bandwidth,
|
||||
|
|
@ -450,7 +493,6 @@ def main(args):
|
|||
config,
|
||||
dynamic=args.dynamic,
|
||||
max_autotune=args.max_autotune,
|
||||
enable_mask=args.mask,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -526,9 +568,6 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
|
|||
action="store_true",
|
||||
help="Calculate kernel memory bandwidth & computational throughput. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask", action="store_true", help="Enables block sparsity mask. "
|
||||
)
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -91,6 +91,19 @@ def _causal(
|
|||
return torch.where(token_q >= token_kv, score, float("-inf"))
|
||||
|
||||
|
||||
def _generate_windowed(offset):
|
||||
def _windowed(score, b, h, q, kv):
|
||||
return torch.where(q + offset >= kv, score, float("-inf"))
|
||||
|
||||
return _windowed
|
||||
|
||||
|
||||
def _get_windowed_sdpa_mask(Mq, Mkv, offset):
|
||||
return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
|
||||
offset : offset + Mq
|
||||
]
|
||||
|
||||
|
||||
def _rel_bias(
|
||||
score: Tensor,
|
||||
batch: Tensor,
|
||||
|
|
@ -171,6 +184,7 @@ test_score_mods = [
|
|||
_rel_bias,
|
||||
_rel_causal,
|
||||
_generate_alibi_bias(8),
|
||||
_generate_windowed(1000),
|
||||
]
|
||||
|
||||
captured_buffers_map = {
|
||||
|
|
@ -825,42 +839,48 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
self.run_test(bias_mod)
|
||||
|
||||
@supported_platform
|
||||
def test_causal_no_mask_vs_sdpa(self):
|
||||
attention = functools.partial(flex_attention, score_mod=_causal)
|
||||
def test_windowed_no_mask_vs_sdpa(self):
|
||||
score_mod = _generate_windowed(1000)
|
||||
attention = functools.partial(flex_attention, score_mod=score_mod)
|
||||
|
||||
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
|
||||
|
||||
sdpa_attention = functools.partial(
|
||||
torch.nn.functional.scaled_dot_product_attention, is_causal=True
|
||||
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
|
||||
)
|
||||
|
||||
self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
|
||||
|
||||
@supported_platform
|
||||
def test_causal_full_mask_vs_sdpa(self):
|
||||
def test_windowed_full_mask_vs_sdpa(self):
|
||||
def mask_mod(b, h, q, kv):
|
||||
return q >= kv
|
||||
return q + 1000 >= kv
|
||||
|
||||
score_mod = _generate_windowed(1000)
|
||||
|
||||
block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
|
||||
attention = functools.partial(
|
||||
flex_attention, block_mask=block_mask, score_mod=_causal
|
||||
flex_attention, block_mask=block_mask, score_mod=score_mod
|
||||
)
|
||||
|
||||
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
|
||||
sdpa_attention = functools.partial(
|
||||
torch.nn.functional.scaled_dot_product_attention, is_causal=True
|
||||
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
|
||||
)
|
||||
|
||||
self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
|
||||
|
||||
@expectedFailure # TODO: add support for partial mask.
|
||||
@supported_platform
|
||||
def test_causal_partial_block_vs_sdpa(self):
|
||||
def test_windowed_partial_block_vs_sdpa(self):
|
||||
def mask_mod(b, h, q, kv):
|
||||
return q >= kv
|
||||
return q + 1000 >= kv
|
||||
|
||||
block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
|
||||
attention = functools.partial(flex_attention, block_mask=block_mask)
|
||||
|
||||
sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
|
||||
sdpa_attention = functools.partial(
|
||||
torch.nn.functional.scaled_dot_product_attention, is_causal=True
|
||||
torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
|
||||
)
|
||||
|
||||
self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ def build_subgraph_buffer(
|
|||
raise ValueError("FlexAttention was passed a subgraph with no output node!")
|
||||
|
||||
|
||||
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
|
||||
compute_next_offset_func = r"""
|
||||
@triton.jit
|
||||
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
|
||||
|
|
@ -126,10 +127,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK
|
|||
return offset
|
||||
"""
|
||||
|
||||
flex_attention_template = TritonTemplate(
|
||||
name="flex_attention",
|
||||
grid=flex_attention_grid,
|
||||
source=r"""
|
||||
compute_flex_attention = r"""
|
||||
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
#
|
||||
|
|
@ -248,6 +246,7 @@ flex_attention_template = TritonTemplate(
|
|||
acc, l_i, m_i,
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
kv_indices, kv_num_blocks,
|
||||
0, kv_num_blocks * SPARSE_KV_MULTIPLE,
|
||||
MATMUL_PRECISION,
|
||||
{{gen_argdefs()}},
|
||||
IS_FULL_BLOCKS=False
|
||||
|
|
@ -285,6 +284,7 @@ flex_attention_template = TritonTemplate(
|
|||
acc, l_i, m_i,
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
kv_indices, kv_num_blocks,
|
||||
0, kv_num_blocks * SPARSE_KV_MULTIPLE,
|
||||
MATMUL_PRECISION,
|
||||
{{gen_argdefs()}},
|
||||
IS_FULL_BLOCKS=True
|
||||
|
|
@ -311,8 +311,10 @@ flex_attention_template = TritonTemplate(
|
|||
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
||||
lse = m_i + tl.math.log2(l_i)
|
||||
tl.store(l_ptrs, lse)
|
||||
"""
|
||||
|
||||
|
||||
compute_forward_block = r"""
|
||||
@triton.jit
|
||||
def forward_inner(
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
|
|
@ -322,6 +324,8 @@ def forward_inner(
|
|||
off_z, off_h, offs_m, offs_n,
|
||||
# blocksparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
# start kv and end kv block
|
||||
block_n_start, block_n_end,
|
||||
MATMUL_PRECISION,
|
||||
{{gen_argdefs()}},
|
||||
IS_FULL_BLOCKS,
|
||||
|
|
@ -339,19 +343,17 @@ def forward_inner(
|
|||
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
||||
|
||||
# loop over k, v and update accumulator
|
||||
lo = 0
|
||||
hi = kv_num_blocks * SPARSE_KV_MULTIPLE
|
||||
|
||||
for start_n in range(0, hi):
|
||||
for start_n in range(block_n_start, block_n_end):
|
||||
# -- load k --
|
||||
k = tl.load(K_block_ptr)
|
||||
# -- compute qk ---
|
||||
qk = tl.dot(q, k)
|
||||
qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2.
|
||||
if not PRESCALE_QK:
|
||||
qk *= SM_SCALE
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
||||
m = offs_m[:, None]
|
||||
n = offs_n[None, :]
|
||||
# TODO: Add load mask in modification when M/N Boundary is not safe
|
||||
{{ modification(
|
||||
subgraph_number=0,
|
||||
output_name="post_mod_scores",
|
||||
|
|
@ -413,8 +415,14 @@ def forward_inner(
|
|||
offs_n = offs_n + offset
|
||||
|
||||
return acc, l_i, m_i
|
||||
"""
|
||||
+ compute_next_offset_func,
|
||||
|
||||
"""
|
||||
|
||||
|
||||
flex_attention_template = TritonTemplate(
|
||||
name="flex_attention",
|
||||
grid=flex_attention_grid,
|
||||
source=compute_flex_attention + compute_forward_block + compute_next_offset_func,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -570,17 +578,6 @@ def flex_attention(
|
|||
subgraph_buffer = build_subgraph_buffer(
|
||||
placeholder_inps + list(score_mod_other_buffers), subgraph
|
||||
)
|
||||
if _use_flex_decoding(query):
|
||||
return create_flex_decoding_kernel(
|
||||
subgraph_buffer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
subgraph,
|
||||
block_mask,
|
||||
scale,
|
||||
*score_mod_other_buffers,
|
||||
)
|
||||
mask_graph_placeholder_inps = [
|
||||
create_placeholder(name, dtype, query.get_device())
|
||||
for name, dtype in [
|
||||
|
|
@ -593,6 +590,18 @@ def flex_attention(
|
|||
mask_graph_buffer = build_subgraph_buffer(
|
||||
mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
|
||||
)
|
||||
if _use_flex_decoding(query):
|
||||
return create_flex_decoding_kernel(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
block_mask,
|
||||
scale,
|
||||
subgraph_buffer,
|
||||
mask_graph_buffer,
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
for buf in [
|
||||
query,
|
||||
key,
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@ import torch
|
|||
from torch._inductor.virtualized import V
|
||||
|
||||
from ..ir import FixedLayout, FlexibleLayout
|
||||
from ..lowering import empty_strided, lowerings
|
||||
from ..lowering import empty, empty_strided, lowerings
|
||||
from ..runtime.runtime_utils import next_power_of_2
|
||||
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
|
||||
from .flex_attention import compute_forward_block, compute_next_offset_func
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
|
@ -32,7 +33,7 @@ flex_decoding_template = TritonTemplate(
|
|||
name="flex_decoding",
|
||||
grid=flex_decoding_grid,
|
||||
source=r"""
|
||||
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX")}}
|
||||
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
# Q: Query, K: Key, V: Value
|
||||
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
||||
|
|
@ -114,21 +115,12 @@ flex_decoding_template = TritonTemplate(
|
|||
sparse_idx_z = off_z % SPARSE_Z
|
||||
sparse_idx_h = off_h % SPARSE_H
|
||||
|
||||
# TODO: strided KV_IDX and KV_NUM_BLKS
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
|
||||
kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
|
||||
sparse_block_num = tl.load(KV_NUM_BLKS + sparse_hz_offset)
|
||||
|
||||
|
||||
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
||||
block_n_last_valid = sparse_block_num * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
|
||||
block_n_end = block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid
|
||||
|
||||
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
# first kv block we're loading
|
||||
# Calculate KV blocks that belong this CTA.
|
||||
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
||||
|
||||
q_offset = off_z * stride_qz + off_h * stride_qh
|
||||
k_offset = off_z * stride_kz + off_h * stride_kh
|
||||
|
|
@ -141,6 +133,32 @@ flex_decoding_template = TritonTemplate(
|
|||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
if SAFE_M_BOUNDARY:
|
||||
q = tl.load(Q_block_ptr)
|
||||
else:
|
||||
q = tl.load(Q_block_ptr, boundary_check=(0, ))
|
||||
|
||||
# initialize offsets
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Apply both score_mod and mask_mod
|
||||
|
||||
# find first kv block we are loading and the number of blocks we are loading
|
||||
kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
|
||||
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
# first kv block we're loading
|
||||
|
||||
block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
|
|
@ -158,6 +176,66 @@ flex_decoding_template = TritonTemplate(
|
|||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
offs_n = tl.arange(0, BLOCK_N) + off_n
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
# accumulatd values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
MATMUL_PRECISION,
|
||||
{{gen_argdefs()}},
|
||||
IS_FULL_BLOCKS=False,
|
||||
)
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We know these blocks are guaranteed to be "full", so we don't need to
|
||||
# apply mask_mod to them - only score_mod
|
||||
if HAS_FULL_BLOCKS:
|
||||
kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
|
||||
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
|
||||
block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(BLOCK_DMODEL, KV_LEN), # (d, N)
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, off_n),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(KV_LEN, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_n, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
offs_n = tl.arange(0, BLOCK_N) + off_n
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
# accumulatd values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
MATMUL_PRECISION,
|
||||
{{gen_argdefs()}},
|
||||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
m_offset = off_h * stride_mh + off_z * stride_mz
|
||||
l_offset = off_h * stride_lh + off_z * stride_lz
|
||||
|
|
@ -178,92 +256,6 @@ flex_decoding_template = TritonTemplate(
|
|||
order=(1, 0)
|
||||
)
|
||||
|
||||
|
||||
# initialize offsets
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N) + off_n
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
if SAFE_M_BOUNDARY:
|
||||
q = tl.load(Q_block_ptr)
|
||||
else:
|
||||
q = tl.load(Q_block_ptr, boundary_check=(0, ))
|
||||
RCP_LN2 = 1.44269504
|
||||
|
||||
if PRESCALE_QK:
|
||||
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
||||
|
||||
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(block_n_start, block_n_end):
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr).to(MATMUL_PRECISION)
|
||||
v = tl.load(V_block_ptr)
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk)
|
||||
if not PRESCALE_QK:
|
||||
qk *= SM_SCALE
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
||||
m = offs_m[:, None]
|
||||
n = offs_n[None, :]
|
||||
# TODO: Add load mask in modification when M/N Boundary is not safe
|
||||
{{ modification(
|
||||
subgraph_number=0,
|
||||
output_name="post_mod_scores",
|
||||
score="qk",
|
||||
b="off_z",
|
||||
h="off_h",
|
||||
m="m",
|
||||
n="n",
|
||||
out="qk"
|
||||
) | indent_except_first(2) }}
|
||||
# TODO: In the case that score_mod is linear, this can be LICMed
|
||||
if not PRESCALE_QK:
|
||||
post_mod_scores *= RCP_LN2
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# -- compute scaling constant ---
|
||||
row_max = tl.max(post_mod_scores, 1)
|
||||
m_i_new = tl.maximum(m_i, row_max)
|
||||
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(post_mod_scores - m_i_new[:, None])
|
||||
if not ROWS_GUARANTEED_SAFE:
|
||||
masked_out_rows = (m_i_new == float("-inf"))
|
||||
alpha = tl.where(masked_out_rows, 0, alpha)
|
||||
p = tl.where(masked_out_rows[:, None], 0, p)
|
||||
|
||||
# -- scale and update acc --
|
||||
acc *= alpha[:, None]
|
||||
acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc=acc)
|
||||
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
|
||||
|
||||
# update pointers
|
||||
indices_idx = start_n // SPARSE_KV_MULTIPLE
|
||||
|
||||
cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
|
||||
next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_block_num)
|
||||
needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
|
||||
jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
|
||||
|
||||
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N
|
||||
|
||||
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
|
||||
offs_n = offs_n + offset
|
||||
|
||||
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
||||
if SAFE_M_BOUNDARY:
|
||||
tl.store(M_block_ptr, m_i[None, :])
|
||||
|
|
@ -281,7 +273,9 @@ flex_decoding_template = TritonTemplate(
|
|||
# TODO generalize and add proper mask support
|
||||
mask = (idx_m < Q_LEN)
|
||||
{{store_output(("idx_z", "idx_h", "idx_t", "idx_m", "idx_d"), "acc", "mask")}}
|
||||
""",
|
||||
"""
|
||||
+ compute_forward_block
|
||||
+ compute_next_offset_func,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -305,36 +299,47 @@ def _get_decoding_default_config(key) -> Tuple[int, int, int]:
|
|||
|
||||
def create_flex_decoding_kernel(*args, **kwargs):
|
||||
(
|
||||
subgraph_buffer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
subgraph,
|
||||
block_mask,
|
||||
scale,
|
||||
*other_buffers,
|
||||
score_mod_subgraph,
|
||||
mask_mod_subgraph,
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
) = args
|
||||
(
|
||||
sparse_kv_num_blocks,
|
||||
sparse_kv_indices,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks, # full_kv_num_blocks,
|
||||
_, # full_kv_indices,
|
||||
full_kv_indices, # full_kv_indices,
|
||||
_, # q_num_blocks
|
||||
_, # q_indices
|
||||
_, # full_q_num_blocks,
|
||||
_, # full_q_indices,
|
||||
SPARSE_KV_BLOCK_SIZE,
|
||||
_, # SPARSE_Q_BLOCK_SIZE,
|
||||
mask_graph,
|
||||
_,
|
||||
) = block_mask
|
||||
if full_kv_num_blocks is not None:
|
||||
raise NotImplementedError("NYI: Flex decoding only supports full mask. ")
|
||||
|
||||
# Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
|
||||
has_full_blocks = full_kv_num_blocks is not None
|
||||
if (
|
||||
full_kv_num_blocks is None
|
||||
): # Create a plackeholder full block list in case it is empty
|
||||
full_kv_num_blocks, full_kv_indices = (
|
||||
empty(0, device=query.get_device()) for _ in range(2)
|
||||
)
|
||||
|
||||
for buf in [
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sparse_kv_num_blocks,
|
||||
sparse_kv_indices,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
full_kv_indices,
|
||||
]:
|
||||
buf.realize()
|
||||
choices: List[Any] = []
|
||||
|
|
@ -412,12 +417,15 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
value,
|
||||
buf_M,
|
||||
buf_L,
|
||||
sparse_kv_num_blocks,
|
||||
sparse_kv_indices,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
full_kv_indices,
|
||||
],
|
||||
layout=layout_acc,
|
||||
subgraphs=[
|
||||
subgraph_buffer,
|
||||
score_mod_subgraph,
|
||||
mask_mod_subgraph,
|
||||
],
|
||||
mutated_inputs=[buf_M, buf_L],
|
||||
num_stages=num_stages,
|
||||
|
|
@ -429,23 +437,32 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
SM_SCALE=scale,
|
||||
# Performance tuning
|
||||
BLOCK_N=BLOCK_N,
|
||||
# Sparse block size
|
||||
SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
|
||||
# For now, we always assume the "sound" option
|
||||
ROWS_GUARANTEED_SAFE=False,
|
||||
SAFE_M_BOUNDARY=(query.get_size()[-2] % BLOCK_M) == 0,
|
||||
SAFE_N_BOUNDARY=True,
|
||||
PRESCALE_QK=False,
|
||||
SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
|
||||
HAS_FULL_BLOCKS=has_full_blocks,
|
||||
)
|
||||
|
||||
inputs_for_flex_decoding = [
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
buf_M,
|
||||
buf_L,
|
||||
sparse_kv_num_blocks,
|
||||
sparse_kv_indices,
|
||||
] + list(other_buffers)
|
||||
inputs_for_flex_decoding = (
|
||||
[
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
buf_M,
|
||||
buf_L,
|
||||
kv_num_blocks,
|
||||
kv_indices,
|
||||
full_kv_num_blocks,
|
||||
full_kv_indices,
|
||||
]
|
||||
+ list(score_mod_other_buffers)
|
||||
+ list(mask_mod_other_buffers)
|
||||
)
|
||||
|
||||
buf_ACC = autotune_select_algorithm(
|
||||
"flex_decoding", choices, inputs_for_flex_decoding, layout_acc
|
||||
)
|
||||
|
|
|
|||
|
|
@ -511,9 +511,6 @@ def _convert_mask_to_block_mask(
|
|||
assert mask.dtype == torch.bool
|
||||
mask = _broadcast_to_dim(mask, 4)
|
||||
B, H, Q, KV = mask.shape
|
||||
is_decoding = Q < 128
|
||||
if is_decoding:
|
||||
Q_BLOCK_SIZE = Q
|
||||
assert Q % Q_BLOCK_SIZE == 0
|
||||
assert KV % KV_BLOCK_SIZE == 0
|
||||
mask = mask.view(
|
||||
|
|
@ -525,7 +522,7 @@ def _convert_mask_to_block_mask(
|
|||
mask_block_sum = mask.sum(
|
||||
dim=[-2, -1]
|
||||
) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE]
|
||||
if separate_full_blocks and not is_decoding:
|
||||
if separate_full_blocks:
|
||||
full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE
|
||||
full_blocks = mask_block_sum == full_block_sum
|
||||
partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum)
|
||||
|
|
@ -732,7 +729,10 @@ def create_block_mask(
|
|||
mod_type == _ModificationType.MASK_MOD
|
||||
), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
|
||||
inner_func = _create_block_mask_inner
|
||||
Q_LEN = Q_LEN if Q_LEN < 128 else _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
|
||||
if Q_LEN < 128:
|
||||
Q_BLOCK_SIZE = Q_LEN
|
||||
else:
|
||||
Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
|
||||
KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
|
||||
if _compile:
|
||||
inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue