mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Inductor] Added and_masks and or_masks utilities & make fully masked out rows 0 instead of nan (#131552)
Combine #131073 and #131012 and fix doc building failures. Co-authored-by: chilli <chilli@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/131552 Approved by: https://github.com/Chillee
This commit is contained in:
parent
89bdd9c18f
commit
a34692c0a3
5 changed files with 116 additions and 14 deletions
|
|
@ -14,6 +14,9 @@ BlockMask Utilities
|
|||
|
||||
.. autofunction:: create_block_mask
|
||||
.. autofunction:: create_mask
|
||||
.. autofunction:: and_masks
|
||||
.. autofunction:: or_masks
|
||||
.. autofunction:: noop_mask
|
||||
|
||||
BlockMask
|
||||
---------
|
||||
|
|
|
|||
|
|
@ -17,9 +17,12 @@ from torch._inductor.utils import run_and_get_code
|
|||
from torch.nn.attention.flex_attention import (
|
||||
_create_empty_block_mask,
|
||||
_identity,
|
||||
and_masks,
|
||||
BlockMask,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
noop_mask,
|
||||
or_masks,
|
||||
)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal import common_utils
|
||||
|
|
@ -199,6 +202,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
):
|
||||
compiled_error = (golden_out - compiled_out).abs().mean()
|
||||
ref_error = (golden_out - ref_out).abs().mean()
|
||||
# TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
|
||||
if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
|
||||
self.assertTrue(False, "Output/Grad with NaN")
|
||||
if compiled_error > ref_error * fudge_factor:
|
||||
|
|
@ -1010,6 +1014,35 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
self.assertEqual(block_mask_a.kv_indices, block_mask_b.kv_indices)
|
||||
self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks)
|
||||
|
||||
@supported_platform
|
||||
def test_mask_mod_combiners(self):
|
||||
def causal_mask(b, h, q, kv):
|
||||
return q >= kv
|
||||
|
||||
def neg_causal_mask(b, h, q, kv):
|
||||
return q < kv
|
||||
|
||||
def sliding_window(b, h, q, kv):
|
||||
return (q - kv) <= 512
|
||||
|
||||
block_mask = create_block_mask(
|
||||
and_masks(causal_mask, sliding_window), 1, 1, S, S
|
||||
)
|
||||
self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
|
||||
attention = functools.partial(flex_attention, block_mask=block_mask)
|
||||
self.run_test_with_call(attention)
|
||||
|
||||
block_mask = create_block_mask(
|
||||
and_masks(causal_mask, neg_causal_mask), 1, 1, S, S
|
||||
)
|
||||
self.assertEqual(block_mask.kv_num_blocks.sum(), 0)
|
||||
|
||||
block_mask1 = create_block_mask(
|
||||
or_masks(causal_mask, neg_causal_mask), 1, 1, S, S
|
||||
)
|
||||
block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
|
||||
self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())
|
||||
|
||||
@supported_platform
|
||||
def test_epilogue_fused(self):
|
||||
@torch.compile
|
||||
|
|
@ -1351,6 +1384,35 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
out = func(query, key, value, block_mask=block_mask)
|
||||
out.sum().backward()
|
||||
|
||||
@supported_platform
|
||||
def test_fully_masked_out_rows(self):
|
||||
# Ensure fully masked out rows won't cause NaNs.
|
||||
query = torch.randn(
|
||||
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
||||
)
|
||||
key = torch.randn(
|
||||
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
||||
)
|
||||
value = torch.randn(
|
||||
(B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
|
||||
)
|
||||
do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")
|
||||
|
||||
M = S // 2
|
||||
|
||||
def mask_mod(b, h, q, kv):
|
||||
return q < M
|
||||
|
||||
block_mask = create_block_mask(mask_mod, 1, 1, S, S)
|
||||
out = torch.compile(flex_attention, dynamic=False)(
|
||||
query, key, value, block_mask=block_mask
|
||||
)
|
||||
# TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
|
||||
self.assertEqual(out[:, :, M:, :].sum(), 0)
|
||||
|
||||
out.backward(do)
|
||||
self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
|
||||
|
||||
@supported_platform
|
||||
def test_comparison_vs_sdpa(self):
|
||||
def causal(score, b, h, q_idx, kv_idx):
|
||||
|
|
|
|||
|
|
@ -146,8 +146,6 @@ def _math_attention_inner(
|
|||
mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
|
||||
mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)
|
||||
|
||||
# todo: We wouldn't need these overrides in this file if Dynamo always did the
|
||||
# rewriting.
|
||||
with TransformGetItemToIndex():
|
||||
scores = (scores * scale).to(working_precision)
|
||||
post_mod_scores = torch.where(
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ flex_attention_template = TritonTemplate(
|
|||
|
||||
|
||||
# Store output and logsumexp
|
||||
l_i = tl.where(l_i == 0, 1, l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
idx_z = tl.program_id(1) // H
|
||||
idx_h = tl.program_id(1) % H
|
||||
|
|
@ -382,13 +383,14 @@ def forward_inner(
|
|||
|
||||
# -- compute scaling constant ---
|
||||
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
||||
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
p = tl.math.exp2(post_mod_scores - m_ij[:, None])
|
||||
if not ROWS_GUARANTEED_SAFE:
|
||||
masked_out_rows = (m_ij == float("-inf"))
|
||||
alpha = tl.where(masked_out_rows, 0, alpha)
|
||||
p = tl.where(masked_out_rows[:, None], 0, p)
|
||||
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
||||
else:
|
||||
m_ij_masked = m_ij
|
||||
|
||||
alpha = tl.math.exp2(m_i - m_ij_masked)
|
||||
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
||||
|
||||
# NB: l_i update is pulled up here since it's a bit faster
|
||||
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
||||
|
|
@ -1144,7 +1146,6 @@ def bwd_dkdv_inner(
|
|||
Di = tl.load(DELTA + offs_m1)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do))
|
||||
# dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
||||
m = offs_m1[None, :]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,15 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
)
|
||||
from torch.nn.attention._utils import _validate_sdpa_input
|
||||
|
||||
__all__ = ["BlockMask", "flex_attention", "create_block_mask", "create_mask"]
|
||||
__all__ = [
|
||||
"BlockMask",
|
||||
"flex_attention",
|
||||
"create_block_mask",
|
||||
"create_mask",
|
||||
"or_masks",
|
||||
"and_masks",
|
||||
"noop_mask",
|
||||
]
|
||||
|
||||
_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
|
||||
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
|
||||
|
|
@ -38,6 +46,7 @@ class _ModificationType(Enum):
|
|||
|
||||
SCORE_MOD = 1
|
||||
MASK_MOD = 2
|
||||
UNKNOWN = 3
|
||||
|
||||
|
||||
@torch._dynamo.assume_constant_result
|
||||
|
|
@ -59,7 +68,7 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
|
|||
elif num_positional_args == 4:
|
||||
return _ModificationType.MASK_MOD
|
||||
else:
|
||||
raise AssertionError
|
||||
return _ModificationType.UNKNOWN
|
||||
|
||||
|
||||
# Need to define it here so that Dynamo doesn't skip it
|
||||
|
|
@ -109,13 +118,14 @@ def _identity(
|
|||
return score
|
||||
|
||||
|
||||
def _no_mask(
|
||||
def noop_mask(
|
||||
batch: Tensor,
|
||||
head: Tensor,
|
||||
token_q: Tensor,
|
||||
token_kv: Tensor,
|
||||
) -> Tensor:
|
||||
return token_q.new_ones(size=(), dtype=torch.bool, device=batch.device)
|
||||
"""Returns a noop mask_mod"""
|
||||
return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
|
||||
|
||||
|
||||
_DEFAULT_SPARSE_BLOCK_SIZE = 128
|
||||
|
|
@ -267,7 +277,7 @@ class BlockMask:
|
|||
BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
|
||||
self.BLOCK_SIZE = BLOCK_SIZE
|
||||
if mask_mod is None:
|
||||
mask_mod = _no_mask
|
||||
mask_mod = noop_mask
|
||||
self.mask_mod = mask_mod
|
||||
|
||||
def as_tuple(self):
|
||||
|
|
@ -480,6 +490,34 @@ def _convert_mask_to_block_mask(
|
|||
return partial_blocks, None
|
||||
|
||||
|
||||
def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
|
||||
"""Returns a mask_mod that's the union of provided mask_mods"""
|
||||
if not all(callable(arg) for arg in mask_mods):
|
||||
raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
|
||||
|
||||
def or_mask(b, h, q_idx, kv_idx):
|
||||
result = b.new_zeros((), dtype=torch.bool)
|
||||
for mask in mask_mods:
|
||||
result = result | mask(b, h, q_idx, kv_idx)
|
||||
return result
|
||||
|
||||
return or_mask
|
||||
|
||||
|
||||
def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
|
||||
"""Returns a mask_mod that's the intersection of provided mask_mods"""
|
||||
if not all(callable(arg) for arg in mask_mods):
|
||||
raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
|
||||
|
||||
def and_mask(b, h, q_idx, kv_idx):
|
||||
result = b.new_ones((), dtype=torch.bool)
|
||||
for mask in mask_mods:
|
||||
result = result & mask(b, h, q_idx, kv_idx)
|
||||
return result
|
||||
|
||||
return and_mask
|
||||
|
||||
|
||||
def _convert_block_mask_to_mask(
|
||||
block_mask,
|
||||
KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
|
||||
|
|
@ -644,7 +682,7 @@ def create_block_mask(
|
|||
mod_type = _get_mod_type(mask_mod)
|
||||
assert (
|
||||
mod_type == _ModificationType.MASK_MOD
|
||||
), "create-block_mask requires a mask_mod function!"
|
||||
), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
|
||||
inner_func = _create_block_mask_inner
|
||||
Q_LEN = Q_LEN if Q_LEN < 128 else round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
|
||||
KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)
|
||||
|
|
|
|||
Loading…
Reference in a new issue