diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py
index d7e28282669..1f040a24d43 100644
--- a/benchmarks/transformer/score_mod.py
+++ b/benchmarks/transformer/score_mod.py
@@ -37,6 +37,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 class ExperimentConfig:
     shape: Tuple[int]
     score_mod: Callable
+    mask_mod: Callable
     dtype: torch.dtype
     calculate_bwd_time: bool
     cal_bandwidth: bool
@@ -122,7 +123,9 @@ def generate_inputs(


 def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
+    config: ExperimentConfig,
+    dynamic=False,
+    max_autotune=False,
 ) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -149,13 +152,14 @@ def run_single_experiment(
     compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)

     score_mod = config.score_mod
+    mask_mod = config.mask_mod

-    if enable_mask:
+    if mask_mod:
         block_mask = create_block_mask(
-            score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
+            mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
         )
     else:
-        block_mask = _create_empty_block_mask(query, key, value)
+        block_mask = _create_empty_block_mask(query, key)

     forward_eager_time = benchmark_torch_function_in_microseconds(
         eager_sdpa, query, key, value, score_mod
@@ -328,7 +332,7 @@ def print_results(results: List[Experiment]):
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))


-def generate_score_mods(score_mods: List[str]) -> List[Callable]:
+def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
     def noop(score, b, h, m, n):
         return score

@@ -343,14 +347,33 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:

     function_dict = {
         "noop": noop,
-        "causal": causal_mask,
+        "causal": None,
         "rel": relative_bias,
         "head_bias": head_bias,
     }
     return [function_dict[name] for name in score_mods]


+def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
+    def noop(b, h, m, n):
+        return True
+
+    def causal(b, h, m, n):
+        return m >= n
+
+    mask_mod_dict = {
+        "noop": None,
+        "causal": causal,
+        "rel": None,
+        "head_bias": None,
+    }
+    return [mask_mod_dict[name] for name in score_mods]
+
+
 def get_gqa_score_mod(score_mod, G, q_seq_len):
+    if score_mod is None:
+        return None
+
     def score_mod_gqa(score, b, hkv, m, n):
         g = m // q_seq_len
         new_m = m % q_seq_len
@@ -362,6 +385,21 @@ def get_gqa_score_mod(score_mod, G, q_seq_len):
     return score_mod_gqa


+def get_gqa_mask_mod(mask_mod, G, q_seq_len):
+    if mask_mod is None:
+        return None
+
+    def mask_mod_gqa(b, h, m, n):
+        g = m // q_seq_len
+        new_m = m % q_seq_len
+        hq = h * G + g
+        return mask_mod(b, hq, new_m, n)
+
+    mask_mod_name = get_func_name(mask_mod)
+    set_func_name(mask_mod_gqa, mask_mod_name + "_gqa")
+    return mask_mod_gqa
+
+
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
@@ -369,7 +407,7 @@ def generate_experiment_configs(
     num_heads: List[Tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
-    score_mods: List[str],
+    score_mods_str: List[str],
     decoding: bool,
     kv_cache_size: List[int],
     cal_bandwidth: bool,
@@ -381,7 +419,8 @@ def generate_experiment_configs(
     else:
         q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
     dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+    score_mods = generate_score_mods(score_mods_str)
+    mask_mods = generate_mask_mods(score_mods_str)
     all_configs = []
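Note on the benchmark change above: `causal` now ships as a `mask_mod` (a boolean predicate that `create_block_mask` can turn into block sparsity) rather than a `score_mod` that writes `-inf`. A minimal sketch of the two equivalent formulations, assuming a CUDA build where `torch.nn.attention.flex_attention` is available; sizes and tolerances are arbitrary:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_score_mod(score, b, h, m, n):
    # score_mod form: write -inf into masked positions; every block still runs.
    return torch.where(m >= n, score, float("-inf"))

def causal_mask_mod(b, h, m, n):
    # mask_mod form: a boolean predicate create_block_mask can use to skip
    # fully-masked KV blocks entirely.
    return m >= n

q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
block_mask = create_block_mask(causal_mask_mod, 1, 1, 1024, 1024)
out_mask = flex_attention(q, k, v, block_mask=block_mask)
out_score = flex_attention(q, k, v, score_mod=causal_score_mod)
torch.testing.assert_close(out_mask, out_score, atol=2e-2, rtol=2e-2)
```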
     for (
         bsz,
@@ -389,6 +428,7 @@
         (q_seq_len, kv_seq_len),
         head_dim,
         score_mod,
+        mask_mod,
         dtype,
     ) in itertools.product(
         kv_cache_size if kv_cache_size else batch_sizes,
@@ -396,6 +436,7 @@
         q_kv_seq_lens,
         head_dims,
         score_mods,
+        mask_mods,
         dtypes,
     ):
         if kv_cache_size:
@@ -410,11 +451,13 @@
             assert q_heads % kv_heads == 0
             G = q_heads // kv_heads
             score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
+            mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)

         all_configs.append(
             ExperimentConfig(
                 shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
                 score_mod=score_mod,
+                mask_mod=mask_mod,
                 dtype=dtype,
                 calculate_bwd_time=calculate_bwd,
                 cal_bandwidth=cal_bandwidth,
@@ -450,7 +493,6 @@ def main(args):
                     config,
                     dynamic=args.dynamic,
                     max_autotune=args.max_autotune,
-                    enable_mask=args.mask,
                 ),
             )
         )
@@ -526,9 +568,6 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
         action="store_true",
         help="Calculate kernel memory bandwidth & computational throughput. ",
     )
-    parser.add_argument(
-        "--mask", action="store_true", help="Enables block sparsity mask. "
-    )

     # Parse arguments
     args = parser.parse_args()
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index c8b369a9963..86254fd8e0a 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -91,6 +91,19 @@ def _causal(
     return torch.where(token_q >= token_kv, score, float("-inf"))


+def _generate_windowed(offset):
+    def _windowed(score, b, h, q, kv):
+        return torch.where(q + offset >= kv, score, float("-inf"))
+
+    return _windowed
+
+
+def _get_windowed_sdpa_mask(Mq, Mkv, offset):
+    return torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool, device="cuda"))[
+        offset : offset + Mq
+    ]
+
+
 def _rel_bias(
     score: Tensor,
     batch: Tensor,
@@ -171,6 +184,7 @@ test_score_mods = [
     _rel_bias,
     _rel_causal,
     _generate_alibi_bias(8),
+    _generate_windowed(1000),
 ]

 captured_buffers_map = {
@@ -825,42 +839,48 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(bias_mod)

     @supported_platform
-    def test_causal_no_mask_vs_sdpa(self):
-        attention = functools.partial(flex_attention, score_mod=_causal)
+    def test_windowed_no_mask_vs_sdpa(self):
+        score_mod = _generate_windowed(1000)
+        attention = functools.partial(flex_attention, score_mod=score_mod)
+
+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
         )

         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

     @supported_platform
-    def test_causal_full_mask_vs_sdpa(self):
+    def test_windowed_full_mask_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
-            return q >= kv
+            return q + 1000 >= kv
+
+        score_mod = _generate_windowed(1000)

         block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
         attention = functools.partial(
-            flex_attention, block_mask=block_mask, score_mod=_causal
+            flex_attention, block_mask=block_mask, score_mod=score_mod
         )

+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
         )

         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)

-    @expectedFailure  # TODO: add support for partial mask.
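The windowed tests above hinge on the sliced lower-triangular SDPA mask agreeing with the predicate `q + offset >= kv`. A quick self-contained check of that equivalence (CPU tensors; the sizes are illustrative):

```python
import torch

Mq, Mkv, offset = 8, 2048, 1000
# Row q of the slice is row q + offset of the full tril, i.e. True where
# kv <= q + offset -- exactly the windowed predicate used by the tests.
sdpa_mask = torch.tril(torch.ones(Mkv, Mkv, dtype=torch.bool))[offset : offset + Mq]
q_idx = torch.arange(Mq)[:, None]
kv_idx = torch.arange(Mkv)[None, :]
assert torch.equal(sdpa_mask, q_idx + offset >= kv_idx)
```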
     @supported_platform
-    def test_causal_partial_block_vs_sdpa(self):
+    def test_windowed_partial_block_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
-            return q >= kv
+            return q + 1000 >= kv

         block_mask = create_block_mask(mask_mod, 1, 1, 8, S)
         attention = functools.partial(flex_attention, block_mask=block_mask)

+        sdpa_mask = _get_windowed_sdpa_mask(8, S, 1000)
         sdpa_attention = functools.partial(
-            torch.nn.functional.scaled_dot_product_attention, is_causal=True
+            torch.nn.functional.scaled_dot_product_attention, attn_mask=sdpa_mask
         )

         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 4000a8e1bbb..c38c673c377 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -113,6 +113,7 @@ def build_subgraph_buffer(
     raise ValueError("FlexAttention was passed a subgraph with no output node!")


+# Inner Triton functions shared by flex_attention & split-k decoding kernels.
 compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
@@ -126,10 +127,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK
     return offset
 """

-flex_attention_template = TritonTemplate(
-    name="flex_attention",
-    grid=flex_attention_grid,
-    source=r"""
+compute_flex_attention = r"""
 {{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
     # Sub notation for this kernel:
     #
@@ -248,6 +246,7 @@ flex_attention_template = TritonTemplate(
             acc, l_i, m_i,
             off_z, off_h, offs_m, offs_n,
             kv_indices, kv_num_blocks,
+            0, kv_num_blocks * SPARSE_KV_MULTIPLE,
             MATMUL_PRECISION,
             {{gen_argdefs()}},
             IS_FULL_BLOCKS=False
@@ -285,6 +284,7 @@ flex_attention_template = TritonTemplate(
                 acc, l_i, m_i,
                 off_z, off_h, offs_m, offs_n,
                 kv_indices, kv_num_blocks,
+                0, kv_num_blocks * SPARSE_KV_MULTIPLE,
                 MATMUL_PRECISION,
                 {{gen_argdefs()}},
                 IS_FULL_BLOCKS=True
@@ -311,8 +311,10 @@ flex_attention_template = TritonTemplate(
     l_ptrs = LSE + off_hz * Q_LEN + offs_m
     lse = m_i + tl.math.log2(l_i)
     tl.store(l_ptrs, lse)
+ """


+compute_forward_block = r"""
 @triton.jit
 def forward_inner(
     q, K_block_ptr, V_block_ptr,
@@ -322,6 +324,8 @@ def forward_inner(
     off_z, off_h, offs_m, offs_n,
     # blocksparse data
     kv_indices, kv_num_blocks,
+    # start kv and end kv block
+    block_n_start, block_n_end,
     MATMUL_PRECISION,
     {{gen_argdefs()}},
     IS_FULL_BLOCKS,
@@ -339,19 +343,17 @@ def forward_inner(
         q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

     # loop over k, v and update accumulator
-    lo = 0
-    hi = kv_num_blocks * SPARSE_KV_MULTIPLE
-
-    for start_n in range(0, hi):
+    for start_n in range(block_n_start, block_n_end):
         # -- load k --
         k = tl.load(K_block_ptr)
         # -- compute qk ---
-        qk = tl.dot(q, k)
+        qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2.
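The new `block_n_start`/`block_n_end` parameters are what let this inner loop be shared: the prefill kernel passes `0, kv_num_blocks * SPARSE_KV_MULTIPLE`, while each decoding CTA passes its own tile range. A pure-PyTorch sketch (not the Triton code) of why the carried `(acc, l_i, m_i)` state composes across disjoint block ranges:

```python
import torch

def forward_inner_ref(q, ks, vs, acc, l_i, m_i, start, end):
    # Online softmax in base-2 log space, mirroring forward_inner's state.
    for i in range(start, end):
        qk = (q @ ks[i].T) * 1.44269504            # RCP_LN2, so exp2 applies
        m_new = torch.maximum(m_i, qk.amax(dim=-1))
        alpha = torch.exp2(m_i - m_new)            # rescale earlier blocks
        p = torch.exp2(qk - m_new[:, None])
        acc = acc * alpha[:, None] + p @ vs[i]
        l_i = l_i * alpha + p.sum(dim=-1)
        m_i = m_new
    return acc, l_i, m_i

q = torch.randn(4, 16)
ks, vs = torch.randn(8, 32, 16), torch.randn(8, 32, 16)
init = (torch.zeros(4, 16), torch.zeros(4), torch.full((4,), -float("inf")))
whole = forward_inner_ref(q, ks, vs, *init, 0, 8)
first = forward_inner_ref(q, ks, vs, *init, 0, 4)
both = forward_inner_ref(q, ks, vs, *first, 4, 8)
for a, b in zip(whole, both):
    torch.testing.assert_close(a, b)  # one pass == two chained sub-ranges
```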
         if not PRESCALE_QK:
             qk *= SM_SCALE
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
         m = offs_m[:, None]
         n = offs_n[None, :]
+        # TODO: Add load mask in modification when M/N Boundary is not safe
         {{ modification(
             subgraph_number=0,
             output_name="post_mod_scores",
@@ -413,8 +415,14 @@ def forward_inner(
         offs_n = offs_n + offset

     return acc, l_i, m_i
-    """
-    + compute_next_offset_func,
+
+"""
+
+
+flex_attention_template = TritonTemplate(
+    name="flex_attention",
+    grid=flex_attention_grid,
+    source=compute_flex_attention + compute_forward_block + compute_next_offset_func,
 )


@@ -570,17 +578,6 @@ def flex_attention(
     subgraph_buffer = build_subgraph_buffer(
         placeholder_inps + list(score_mod_other_buffers), subgraph
     )
-    if _use_flex_decoding(query):
-        return create_flex_decoding_kernel(
-            subgraph_buffer,
-            query,
-            key,
-            value,
-            subgraph,
-            block_mask,
-            scale,
-            *score_mod_other_buffers,
-        )
     mask_graph_placeholder_inps = [
         create_placeholder(name, dtype, query.get_device())
         for name, dtype in [
@@ -593,6 +590,18 @@ def flex_attention(
     mask_graph_buffer = build_subgraph_buffer(
         mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
     )
+    if _use_flex_decoding(query):
+        return create_flex_decoding_kernel(
+            query,
+            key,
+            value,
+            block_mask,
+            scale,
+            subgraph_buffer,
+            mask_graph_buffer,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
     for buf in [
         query,
         key,
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 911fbf1c78e..5b564dc7a46 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -8,9 +8,10 @@ import torch
 from torch._inductor.virtualized import V

 from ..ir import FixedLayout, FlexibleLayout
-from ..lowering import empty_strided, lowerings
+from ..lowering import empty, empty_strided, lowerings
 from ..runtime.runtime_utils import next_power_of_2
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
+from .flex_attention import compute_forward_block, compute_next_offset_func

 aten = torch.ops.aten

@@ -32,7 +33,7 @@ flex_decoding_template = TritonTemplate(
     name="flex_decoding",
     grid=flex_decoding_grid,
     source=r"""
-    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX")}}
+    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
     # Sub notation for this kernel:
     # Q: Query, K: Key, V: Value
     # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
@@ -114,21 +115,12 @@ flex_decoding_template = TritonTemplate(
     sparse_idx_z = off_z % SPARSE_Z
     sparse_idx_h = off_h % SPARSE_H

+    # TODO: strided KV_IDX and KV_NUM_BLKS
     sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
-    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
-    sparse_block_num = tl.load(KV_NUM_BLKS + sparse_hz_offset)
-
-    block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
-    block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
-    block_n_last_valid = sparse_block_num * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
-    block_n_end = block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid
-
-
-    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
-    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
-    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
-    # first kv block we're loading
+    # Calculate KV blocks that belong to this CTA.
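For orientation, the tiling arithmetic this comment introduces, as a toy Python model (all sizes are made-up examples, not kernel constants): each of `SPLIT_KV` CTAs owns `TILE_KV_MULTIPLE` consecutive `BLOCK_N` tiles, and the range is clamped to the last valid block of the sparse mask.

```python
TILE_KV_MULTIPLE = 4        # BLOCK_N tiles per CTA (TILE_KV // BLOCK_N)
SPLIT_KV = 4                # CTAs cooperating on one (batch, head)
SPARSE_KV_MULTIPLE = 4      # BLOCK_N tiles per sparse block
kv_num_blocks = 3           # valid sparse blocks for this (z, h)
block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE

for off_t in range(SPLIT_KV):
    block_n_start = off_t * TILE_KV_MULTIPLE
    block_n_end = min(block_n_start + TILE_KV_MULTIPLE, block_n_last_valid)
    print(f"CTA {off_t}: BLOCK_N tiles [{block_n_start}, {block_n_end})")
```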
+    block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
+    block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N

     q_offset = off_z * stride_qz + off_h * stride_qh
     k_offset = off_z * stride_kz + off_h * stride_kh
@@ -141,6 +133,32 @@ flex_decoding_template = TritonTemplate(
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
+    if SAFE_M_BOUNDARY:
+        q = tl.load(Q_block_ptr)
+    else:
+        q = tl.load(Q_block_ptr, boundary_check=(0, ))
+
+    # initialize offsets
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # Apply both score_mod and mask_mod
+
+    # find first kv block we are loading and the number of blocks we are loading
+    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
+    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
+    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
+    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
+    # first kv block we're loading
+
+    block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask

     K_block_ptr = tl.make_block_ptr(
         base=K + k_offset,
@@ -158,6 +176,66 @@ flex_decoding_template = TritonTemplate(
         block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
+    offs_n = tl.arange(0, BLOCK_N) + off_n
+
+    acc, l_i, m_i = forward_inner(
+        q, K_block_ptr, V_block_ptr,
+        # accumulated values
+        acc, l_i, m_i,
+        # offsets
+        off_z, off_h, offs_m, offs_n,
+        # block sparse data
+        kv_indices, kv_num_blocks,
+        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
+        MATMUL_PRECISION,
+        {{gen_argdefs()}},
+        IS_FULL_BLOCKS=False,
+    )
+
+
+    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # We know these blocks are guaranteed to be "full", so we don't need to
+    # apply mask_mod to them - only score_mod
+    if HAS_FULL_BLOCKS:
+        kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
+        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
+        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
+        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
+
+        block_n_last_valid = kv_num_blocks * SPARSE_KV_MULTIPLE # last valid block according to sparse mask
+
+        K_block_ptr = tl.make_block_ptr(
+            base=K + k_offset,
+            shape=(BLOCK_DMODEL, KV_LEN),                # (d, N)
+            strides=(stride_kk, stride_kn),
+            offsets=(0, off_n),
+            block_shape=(BLOCK_DMODEL, BLOCK_N),
+            order=(0, 1)
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + v_offset,
+            shape=(KV_LEN, BLOCK_DMODEL),
+            strides=(stride_vk, stride_vn),
+            offsets=(off_n, 0),
+            block_shape=(BLOCK_N, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+        offs_n = tl.arange(0, BLOCK_N) + off_n
+
+        acc, l_i, m_i = forward_inner(
+            q, K_block_ptr, V_block_ptr,
+            # accumulated values
+            acc, l_i, m_i,
+            # offsets
+            off_z, off_h, offs_m, offs_n,
+            # block sparse data
+            kv_indices, kv_num_blocks,
+            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
+            MATMUL_PRECISION,
+            {{gen_argdefs()}},
+            IS_FULL_BLOCKS=True,
+        )

     m_offset = off_h * stride_mh + off_z * stride_mz
     l_offset = off_h * stride_lh + off_z * stride_lz
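The two `forward_inner` calls above mirror the prefill kernel's split: partially-masked blocks evaluate `mask_mod` element-wise, while provably full blocks only run `score_mod`. A runnable toy version of that dispatch in plain PyTorch (`attend_with_block_lists` and the windowed predicate are hypothetical, for illustration only):

```python
import torch

def attend_with_block_lists(q, k, v, block, partial_ids, full_ids, mask_mod):
    # Columns in no listed block stay -inf, i.e. fully masked out.
    scores = torch.full((q.shape[0], k.shape[0]), float("-inf"))
    m = torch.arange(q.shape[0])[:, None]
    for i in partial_ids + full_ids:
        s = slice(i * block, (i + 1) * block)
        blk = q @ k[s].T
        if i in partial_ids:  # only partial blocks pay for mask_mod
            n = torch.arange(s.start, s.stop)[None, :]
            blk = blk.masked_fill(~mask_mod(m, n), float("-inf"))
        scores[:, s] = blk
    return torch.softmax(scores, dim=-1) @ v

q, k, v = torch.randn(8, 16), torch.randn(128, 16), torch.randn(128, 16)
windowed = lambda m, n: m + 120 >= n      # every row sees blocks 0-2 fully
out = attend_with_block_lists(q, k, v, 32, [3], [0, 1, 2], windowed)
```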
@@ -178,92 +256,6 @@ flex_decoding_template = TritonTemplate(
         order=(1, 0)
     )
-
-    # initialize offsets
-    offs_m = tl.arange(0, BLOCK_M)
-    offs_n = tl.arange(0, BLOCK_N) + off_n
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-
-    # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
-    if SAFE_M_BOUNDARY:
-        q = tl.load(Q_block_ptr)
-    else:
-        q = tl.load(Q_block_ptr, boundary_check=(0, ))
-
     RCP_LN2 = 1.44269504
-
-    if PRESCALE_QK:
-        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
-
-
-    # loop over k, v and update accumulator
-    for start_n in range(block_n_start, block_n_end):
-        # -- load k, v --
-        k = tl.load(K_block_ptr).to(MATMUL_PRECISION)
-        v = tl.load(V_block_ptr)
-        # -- compute qk ---
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk = tl.dot(q, k, acc=qk)
-        if not PRESCALE_QK:
-            qk *= SM_SCALE
-
-        # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
-        m = offs_m[:, None]
-        n = offs_n[None, :]
-        # TODO: Add load mask in modification when M/N Boundary is not safe
-        {{ modification(
-            subgraph_number=0,
-            output_name="post_mod_scores",
-            score="qk",
-            b="off_z",
-            h="off_h",
-            m="m",
-            n="n",
-            out="qk"
-        ) | indent_except_first(2) }}
-        # TODO: In the case that score_mod is linear, this can be LICMed
-        if not PRESCALE_QK:
-            post_mod_scores *= RCP_LN2
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-        # -- compute scaling constant ---
-        row_max = tl.max(post_mod_scores, 1)
-        m_i_new = tl.maximum(m_i, row_max)
-
-        alpha = tl.math.exp2(m_i - m_i_new)
-        p = tl.math.exp2(post_mod_scores - m_i_new[:, None])
-        if not ROWS_GUARANTEED_SAFE:
-            masked_out_rows = (m_i_new == float("-inf"))
-            alpha = tl.where(masked_out_rows, 0, alpha)
-            p = tl.where(masked_out_rows[:, None], 0, p)
-
-        # -- scale and update acc --
-        acc *= alpha[:, None]
-        acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc=acc)
-
-        # -- update m_i and l_i --
-        l_i = l_i * alpha + tl.sum(p, 1)
-        m_i = m_i_new
-
-
-        # update pointers
-        indices_idx = start_n // SPARSE_KV_MULTIPLE
-
-        cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
-        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_block_num)
-        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
-        jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
-
-        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N
-
-
-        K_block_ptr = tl.advance(K_block_ptr, (0, offset))
-        V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
-        offs_n = offs_n + offset
-
     # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
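Because each KV split stores its local rowmax `M` and sumexp `L` here, the cross-CTA combine is just another online-softmax step. A sketch of that reduction math (illustrative; the actual combine is emitted separately by the lowering):

```python
import torch

def combine_splits(ms, ls, accs):
    # ms[s]: per-split rowmax (base-2 log space); ls[s]: sumexp; accs[s]:
    # unnormalized partial output for split s.
    m = torch.stack(ms).amax(dim=0)            # global rowmax
    out, l = torch.zeros_like(accs[0]), torch.zeros_like(ls[0])
    for m_s, l_s, acc_s in zip(ms, ls, accs):
        alpha = torch.exp2(m_s - m)            # rescale split to global max
        out = out + acc_s * alpha[:, None]
        l = l + l_s * alpha
    return out / l[:, None]                    # final attention output
```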
     if SAFE_M_BOUNDARY:
         tl.store(M_block_ptr, m_i[None, :])
@@ -281,7 +273,9 @@ flex_decoding_template = TritonTemplate(
     # TODO generalize and add proper mask support
     mask = (idx_m < Q_LEN)
     {{store_output(("idx_z", "idx_h", "idx_t", "idx_m", "idx_d"), "acc", "mask")}}
-    """,
+    """
+    + compute_forward_block
+    + compute_next_offset_func,
 )


@@ -305,36 +299,47 @@ def _get_decoding_default_config(key) -> Tuple[int, int, int]:

 def create_flex_decoding_kernel(*args, **kwargs):
     (
-        subgraph_buffer,
         query,
         key,
         value,
-        subgraph,
         block_mask,
         scale,
-        *other_buffers,
+        score_mod_subgraph,
+        mask_mod_subgraph,
+        score_mod_other_buffers,
+        mask_mod_other_buffers,
     ) = args
     (
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
+        kv_num_blocks,
+        kv_indices,
         full_kv_num_blocks,  # full_kv_num_blocks,
-        _,  # full_kv_indices,
+        full_kv_indices,  # full_kv_indices,
         _,  # q_num_blocks
         _,  # q_indices
         _,  # full_q_num_blocks,
         _,  # full_q_indices,
         SPARSE_KV_BLOCK_SIZE,
         _,  # SPARSE_Q_BLOCK_SIZE,
-        mask_graph,
+        _,
     ) = block_mask
-    if full_kv_num_blocks is not None:
-        raise NotImplementedError("NYI: Flex decoding only supports full mask. ")
+
+    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
+    has_full_blocks = full_kv_num_blocks is not None
+    if (
+        full_kv_num_blocks is None
+    ):  # Create a placeholder full block list in case it is empty
+        full_kv_num_blocks, full_kv_indices = (
+            empty(0, device=query.get_device()) for _ in range(2)
+        )
+
     for buf in [
         query,
         key,
         value,
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
+        kv_num_blocks,
+        kv_indices,
+        full_kv_num_blocks,
+        full_kv_indices,
     ]:
         buf.realize()

     choices: List[Any] = []
@@ -412,12 +417,15 @@ def create_flex_decoding_kernel(*args, **kwargs):
             value,
             buf_M,
             buf_L,
-            sparse_kv_num_blocks,
-            sparse_kv_indices,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
         ],
         layout=layout_acc,
         subgraphs=[
-            subgraph_buffer,
+            score_mod_subgraph,
+            mask_mod_subgraph,
         ],
         mutated_inputs=[buf_M, buf_L],
         num_stages=num_stages,
@@ -429,23 +437,32 @@ def create_flex_decoding_kernel(*args, **kwargs):
         SM_SCALE=scale,
         # Performance tuning
         BLOCK_N=BLOCK_N,
+        # Sparse block size
+        SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
         # For now, we always assume the "sound" option
         ROWS_GUARANTEED_SAFE=False,
         SAFE_M_BOUNDARY=(query.get_size()[-2] % BLOCK_M) == 0,
         SAFE_N_BOUNDARY=True,
         PRESCALE_QK=False,
-        SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
+        HAS_FULL_BLOCKS=has_full_blocks,
     )

-    inputs_for_flex_decoding = [
-        query,
-        key,
-        value,
-        buf_M,
-        buf_L,
-        sparse_kv_num_blocks,
-        sparse_kv_indices,
-    ] + list(other_buffers)
+    inputs_for_flex_decoding = (
+        [
+            query,
+            key,
+            value,
+            buf_M,
+            buf_L,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+        ]
+        + list(score_mod_other_buffers)
+        + list(mask_mod_other_buffers)
+    )
+
     buf_ACC = autotune_select_algorithm(
         "flex_decoding", choices, inputs_for_flex_decoding, layout_acc
     )
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 6cc3fc7a362..a9e97499dcc 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -511,9 +511,6 @@ def _convert_mask_to_block_mask(
     assert mask.dtype == torch.bool
     mask = _broadcast_to_dim(mask, 4)
     B, H, Q, KV = mask.shape
-    is_decoding = Q < 128
-    if is_decoding:
-        Q_BLOCK_SIZE = Q
     assert Q % Q_BLOCK_SIZE == 0
     assert KV % KV_BLOCK_SIZE == 0
     mask = mask.view(
@@ -525,7 +522,7 @@
     mask_block_sum = mask.sum(
         dim=[-2, -1]
     )  # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE]
-    if separate_full_blocks and not is_decoding:
+    if separate_full_blocks:
         full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE
         full_blocks = mask_block_sum == full_block_sum
         partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum)
@@ -732,7 +729,10 @@ def create_block_mask(
         mod_type == _ModificationType.MASK_MOD
     ), f"create_block_mask requires a mask_mod function! Got {mask_mod}"
     inner_func = _create_block_mask_inner
-    Q_LEN = Q_LEN if Q_LEN < 128 else _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
+    if Q_LEN < 128:
+        Q_BLOCK_SIZE = Q_LEN
+    else:
+        Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
     KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
     if _compile:
         inner_func = torch.compile(inner_func, fullgraph=True, dynamic=False)
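To make the block-mask plumbing above concrete: a block is "full" when every element is unmasked (so the `HAS_FULL_BLOCKS` path can skip `mask_mod`), and "partial" when only some are. An illustrative re-derivation with made-up sizes; `QB == Q` mirrors the decoding branch that sets `Q_BLOCK_SIZE = Q_LEN` when `Q_LEN < 128`:

```python
import torch

B, H, Q, KV, QB, KB = 1, 1, 8, 2048, 8, 128
q_idx, kv_idx = torch.arange(Q)[:, None], torch.arange(KV)[None, :]
mask = (q_idx + 1000 >= kv_idx).repeat(B, H, 1, 1)   # dense windowed mask
blocks = mask.view(B, H, Q // QB, QB, KV // KB, KB).permute(0, 1, 2, 4, 3, 5)
block_sum = blocks.sum(dim=[-2, -1])
full_blocks = block_sum == QB * KB                   # mask_mod provably True
partial_blocks = (block_sum > 0) & (block_sum < QB * KB)
print(full_blocks.sum().item(), partial_blocks.sum().item())  # 7 full, 1 partial
```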