[FlexAttention] Optimize learned bias perf by moving grad scatters to the dq calc (#142281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142281
Approved by: https://github.com/Chillee
drisspg 2024-12-14 16:14:57 -08:00 committed by PyTorch MergeBot
parent e0bdae7884
commit 744a303dee
3 changed files with 36 additions and 18 deletions

@@ -3442,7 +3442,7 @@ class GraphModule(torch.nn.Module):
score_mod_0 = self.score_mod_0
mask_fn_0 = self.mask_fn_0
-flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
return (out,)
@@ -3483,7 +3483,7 @@ class GraphModule(torch.nn.Module):
fw_graph0 = self.fw_graph0
joint_graph0 = self.joint_graph0
mask_graph0 = self.mask_graph0
-flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
+flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
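
For context, the expected graphs above come from compiling flex_attention with a score_mod that reads a captured, trainable bias tensor; the backward graph then has to scatter gradient contributions back into that tensor, which is exactly the work WRITE_DQ schedules. A minimal sketch of such a setup (illustrative only; the shapes and names here are assumptions, not taken from this PR's tests):

import torch
from torch.nn.attention.flex_attention import flex_attention

# A learned bias captured by score_mod: reading it makes it an "other buffer"
# whose gradient is accumulated via scatters in the backward kernel.
bias = torch.randn(128, 128, device="cuda", requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx, kv_idx]

q, k, v = (torch.randn(2, 2, 128, 64, device="cuda", requires_grad=True) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=score_mod)
out.sum().backward()  # fills q.grad, k.grad, v.grad, and bias.grad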

@@ -1787,6 +1787,21 @@ def bwd_dq_block_mn(
if CHECK_BLOCK_BOUNDARY:
    grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
+if WRITE_DQ:
+    scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN
+    {{ modification(
+        subgraph_number=3,
+        output_name=None,
+        mask="scatter_mask",
+        score="pre_mod_scores",
+        b="off_z",
+        h="off_hq",
+        m="m",
+        n="n",
+        grad_score_mod="ds"
+    ) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
@@ -1975,22 +1990,23 @@ def bwd_dkdv_block_mn(
) | indent_except_first(1) }}
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
-idx_b = off_z
-idx_h = off_hq
-idx_m = m
-idx_n = n
-scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
-{{ modification(
-    subgraph_number=3,
-    output_name=None,
-    mask="scatter_mask",
-    score="pre_mod_scores",
-    b="idx_b",
-    h="idx_h",
-    m="idx_m",
-    n="idx_n",
-    grad_score_mod="dsT"
-) | indent_except_first(1) }}
+if not WRITE_DQ:
+    idx_b = off_z
+    idx_h = off_hq
+    idx_m = m
+    idx_n = n
+    scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
+    {{ modification(
+        subgraph_number=3,
+        output_name=None,
+        mask="scatter_mask",
+        score="pre_mod_scores",
+        b="idx_b",
+        h="idx_h",
+        m="idx_m",
+        n="idx_n",
+        grad_score_mod="dsT"
+    ) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
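
The net effect of the two template hunks above: the subgraph-3 modification (the captured-buffer gradient scatter) now runs inside bwd_dq_block_mn when WRITE_DQ is set, and only falls back to bwd_dkdv_block_mn when it is not. In the dQ loop the gradient block ds is already laid out as [BLOCK_M, BLOCK_N], whereas the dK/dV loop holds the transposed dsT, so scattering there pays for a worse access pattern. A runnable toy of that scheduling choice (plain PyTorch, not the Triton kernel; block counts and names are invented for illustration):

import torch

# Both schedules accumulate the same bias gradient; WRITE_DQ only picks
# which loop nest performs the adds (atomics, in the real kernel).
BM = BN = 4
ds_blocks = {(m, n): torch.randn(BM, BN) for m in range(2) for n in range(2)}
bias_grad = torch.zeros(2 * BM, 2 * BN)

WRITE_DQ = True
if WRITE_DQ:
    # dQ schedule: a fixed Q block iterates over KV blocks; ds is already
    # [BLOCK_M, BLOCK_N], so the scatter needs no transpose.
    for m in range(2):
        for n in range(2):
            bias_grad[m * BM:(m + 1) * BM, n * BN:(n + 1) * BN] += ds_blocks[m, n]
else:
    # dK/dV schedule: a fixed KV block iterates over Q blocks and holds
    # dsT = ds.T, so an [M, N]-layout write implies a transpose per block.
    for n in range(2):
        for m in range(2):
            dsT = ds_blocks[m, n].T
            bias_grad[m * BM:(m + 1) * BM, n * BN:(n + 1) * BN] += dsT.T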

@@ -1083,6 +1083,8 @@ def _apply_kernel_options(
kernel_options.setdefault("PRESCALE_QK", False)
kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False)
kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False)
+# This forces all bias grad scatters to be done in the DQ iteration loop of the backward pass
+kernel_options.setdefault("WRITE_DQ", True)
# Whether the forward kernel needs to return logsumexp is decided internally by this rule.
assert "OUTPUT_LOGSUMEXP" not in kernel_options