mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[FlexAttention] Optimizing learned bias perf to dq calc (#142281)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142281 Approved by: https://github.com/Chillee
This commit is contained in:
parent
e0bdae7884
commit
744a303dee
3 changed files with 36 additions and 18 deletions
|
|
@ -3442,7 +3442,7 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
score_mod_0 = self.score_mod_0
|
||||
mask_fn_0 = self.mask_fn_0
|
||||
flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
|
||||
flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
|
||||
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
|
||||
return (out,)
|
||||
|
||||
|
|
@ -3483,7 +3483,7 @@ class GraphModule(torch.nn.Module):
|
|||
fw_graph0 = self.fw_graph0
|
||||
joint_graph0 = self.joint_graph0
|
||||
mask_graph0 = self.mask_graph0
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
|
||||
getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
|
||||
getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
|
||||
getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
|
||||
|
|
|
|||
|
|
@ -1787,6 +1787,21 @@ def bwd_dq_block_mn(
|
|||
if CHECK_BLOCK_BOUNDARY:
|
||||
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
||||
if WRITE_DQ:
|
||||
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN
|
||||
{{ modification(
|
||||
subgraph_number=3,
|
||||
output_name=None,
|
||||
mask="scatter_mask",
|
||||
score="pre_mod_scores",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
grad_score_mod="ds"
|
||||
) | indent_except_first(2) }}
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
ds = grad_scores
|
||||
|
||||
if not IS_FULL_BLOCKS:
|
||||
|
|
@ -1975,22 +1990,23 @@ def bwd_dkdv_block_mn(
|
|||
) | indent_except_first(1) }}
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
||||
idx_b = off_z
|
||||
idx_h = off_hq
|
||||
idx_m = m
|
||||
idx_n = n
|
||||
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
|
||||
{{ modification(
|
||||
subgraph_number=3,
|
||||
output_name=None,
|
||||
mask="scatter_mask",
|
||||
score="pre_mod_scores",
|
||||
b="idx_b",
|
||||
h="idx_h",
|
||||
m="idx_m",
|
||||
n="idx_n",
|
||||
grad_score_mod="dsT"
|
||||
) | indent_except_first(1) }}
|
||||
if not WRITE_DQ:
|
||||
idx_b = off_z
|
||||
idx_h = off_hq
|
||||
idx_m = m
|
||||
idx_n = n
|
||||
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
|
||||
{{ modification(
|
||||
subgraph_number=3,
|
||||
output_name=None,
|
||||
mask="scatter_mask",
|
||||
score="pre_mod_scores",
|
||||
b="idx_b",
|
||||
h="idx_h",
|
||||
m="idx_m",
|
||||
n="idx_n",
|
||||
grad_score_mod="dsT"
|
||||
) | indent_except_first(2) }}
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
|
|
|
|||
|
|
@ -1083,6 +1083,8 @@ def _apply_kernel_options(
|
|||
kernel_options.setdefault("PRESCALE_QK", False)
|
||||
kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False)
|
||||
kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False)
|
||||
# This forces all biases grad scatters to be done in the DQ iteration loop of the backwards
|
||||
kernel_options.setdefault("WRITE_DQ", True)
|
||||
|
||||
# If forward kernel needs to return logsumexp is decided by this rule internally.
|
||||
assert "OUTPUT_LOGSUMEXP" not in kernel_options
|
||||
|
|
|
|||
Loading…
Reference in a new issue