[Inductor] Added and_masks and or_masks utilities & make fully masked out rows 0 instead of nan (#131552)

Combine #131073 and #131012 and fix doc building failures.

Co-authored-by: chilli <chilli@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131552
Approved by: https://github.com/Chillee
Authored by Yanbo Liang on 2024-07-25 21:29:43 +00:00; committed by PyTorch MergeBot
parent 89bdd9c18f
commit a34692c0a3
5 changed files with 116 additions and 14 deletions
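
For context, a minimal sketch of how the new combinators are intended to be used (the mask functions, sequence length, head dim, and CUDA device below are illustrative, not taken from this PR):

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
    or_masks,
)


# Two simple mask_mods (illustrative, mirroring the ones used in the tests below).
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx) <= 512


S, D = 1024, 64
# and_masks intersects mask_mods: attend only where *all* of them allow it.
sliding_causal = create_block_mask(and_masks(causal_mask, sliding_window), 1, 1, S, S)
# or_masks unions mask_mods: attend where *any* of them allows it.
either = create_block_mask(or_masks(causal_mask, sliding_window), 1, 1, S, S)
print(sliding_causal.sparsity(), either.sparsity())

q, k, v = (torch.randn(1, 1, S, D, device="cuda") for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, block_mask=sliding_causal)
```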


@@ -14,6 +14,9 @@ BlockMask Utilities
.. autofunction:: create_block_mask
.. autofunction:: create_mask
.. autofunction:: and_masks
.. autofunction:: or_masks
.. autofunction:: noop_mask
BlockMask
---------


@@ -17,9 +17,12 @@ from torch._inductor.utils import run_and_get_code
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
_identity,
and_masks,
BlockMask,
create_block_mask,
flex_attention,
noop_mask,
or_masks,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
@@ -199,6 +202,7 @@ class TestFlexAttention(InductorTestCase):
):
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
# TODO: Make this check stricter after updating eager SDPA masked_softmax semantics
if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any():
self.assertTrue(False, "Output/Grad with NaN")
if compiled_error > ref_error * fudge_factor:
@@ -1010,6 +1014,35 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        self.assertEqual(block_mask_a.kv_indices, block_mask_b.kv_indices)
        self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks)

    @supported_platform
    def test_mask_mod_combiners(self):
        def causal_mask(b, h, q, kv):
            return q >= kv

        def neg_causal_mask(b, h, q, kv):
            return q < kv

        def sliding_window(b, h, q, kv):
            return (q - kv) <= 512

        block_mask = create_block_mask(
            and_masks(causal_mask, sliding_window), 1, 1, S, S
        )
        self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
        attention = functools.partial(flex_attention, block_mask=block_mask)
        self.run_test_with_call(attention)

        block_mask = create_block_mask(
            and_masks(causal_mask, neg_causal_mask), 1, 1, S, S
        )
        self.assertEqual(block_mask.kv_num_blocks.sum(), 0)

        block_mask1 = create_block_mask(
            or_masks(causal_mask, neg_causal_mask), 1, 1, S, S
        )
        block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
        self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())

    @supported_platform
    def test_epilogue_fused(self):
        @torch.compile
@@ -1351,6 +1384,35 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        out = func(query, key, value, block_mask=block_mask)
        out.sum().backward()

    @supported_platform
    def test_fully_masked_out_rows(self):
        # Ensure fully masked out rows won't cause NaNs.
        query = torch.randn(
            (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        key = torch.randn(
            (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        value = torch.randn(
            (B, H, S, D), dtype=torch.float32, device="cuda", requires_grad=True
        )
        do = torch.randn((B, H, S, D), dtype=torch.float32, device="cuda")

        M = S // 2

        def mask_mod(b, h, q, kv):
            return q < M

        block_mask = create_block_mask(mask_mod, 1, 1, S, S)
        out = torch.compile(flex_attention, dynamic=False)(
            query, key, value, block_mask=block_mask
        )
        # TODO: Switch to self.run_test_with_call after updating eager SDPA masked_softmax semantics
        self.assertEqual(out[:, :, M:, :].sum(), 0)
        out.backward(do)
        self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

    @supported_platform
    def test_comparison_vs_sdpa(self):
        def causal(score, b, h, q_idx, kv_idx):


@@ -146,8 +146,6 @@ def _math_attention_inner(
mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)
# todo: We wouldn't need these overrides in this file if Dynamo always did the
# rewriting.
with TransformGetItemToIndex():
scores = (scores * scale).to(working_precision)
post_mod_scores = torch.where(


@@ -292,6 +292,7 @@ flex_attention_template = TritonTemplate(
# Store output and logsumexp
l_i = tl.where(l_i == 0, 1, l_i)
acc = acc / l_i[:, None]
idx_z = tl.program_id(1) // H
idx_h = tl.program_id(1) % H
@@ -382,13 +383,14 @@ def forward_inner(
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
alpha = tl.math.exp2(m_i - m_ij)
p = tl.math.exp2(post_mod_scores - m_ij[:, None])
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
alpha = tl.where(masked_out_rows, 0, alpha)
p = tl.where(masked_out_rows[:, None], 0, p)
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
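
For context (not part of the diff), here is a minimal plain-PyTorch sketch of why a fully masked row previously produced NaN in the online-softmax update, and how substituting 0 for m_ij on masked-out rows, as the kernel change above does, makes the row come out as 0. The tensor shapes and names are illustrative only:

```python
import torch

# One query row in which every kv position is masked out: all scores are -inf.
scores = torch.full((4,), float("-inf"))
m_i = torch.tensor(float("-inf"))  # running row max starts at -inf

m_ij = torch.maximum(m_i, scores.max())  # still -inf for a fully masked row

# Naive online-softmax rescaling: (-inf) - (-inf) = nan and exp2(nan) = nan,
# which then propagates into the accumulator and the output row.
alpha_naive = torch.exp2(m_i - m_ij)      # nan
p_naive = torch.exp2(scores - m_ij)       # nan, nan, nan, nan

# The fix: detect fully masked rows and use 0 in place of -inf for m_ij,
# so exp2(-inf - 0) = 0 and the row contributes nothing instead of NaN.
masked_out = m_ij == float("-inf")
m_ij_masked = torch.where(masked_out, torch.zeros_like(m_ij), m_ij)
alpha = torch.exp2(m_i - m_ij_masked)     # 0.0
p = torch.exp2(scores - m_ij_masked)      # 0.0, 0.0, 0.0, 0.0

print(alpha_naive.item(), alpha.item())   # nan 0.0
```

Because l_i for such a row is then 0, the template also guards the final division with l_i = tl.where(l_i == 0, 1, l_i) (the one-line change in the earlier hunk of this file), so the stored output row is exactly 0.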
@@ -1144,7 +1146,6 @@ def bwd_dkdv_inner(
Di = tl.load(DELTA + offs_m1)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
# dpT = tl.where(offs_m1[None, :] < Q_LEN, dpT, 0.0)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
m = offs_m1[None, :]


@@ -23,7 +23,15 @@ from torch.fx.experimental.proxy_tensor import (
)
from torch.nn.attention._utils import _validate_sdpa_input
__all__ = ["BlockMask", "flex_attention", "create_block_mask", "create_mask"]
__all__ = [
"BlockMask",
"flex_attention",
"create_block_mask",
"create_mask",
"or_masks",
"and_masks",
"noop_mask",
]
_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
@@ -38,6 +46,7 @@ class _ModificationType(Enum):
SCORE_MOD = 1
MASK_MOD = 2
UNKNOWN = 3
@torch._dynamo.assume_constant_result
@@ -59,7 +68,7 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
elif num_positional_args == 4:
return _ModificationType.MASK_MOD
else:
raise AssertionError
return _ModificationType.UNKNOWN
# Need to define it here so that Dynamo doesn't skip it
@@ -109,13 +118,14 @@ def _identity(
return score
def _no_mask(
def noop_mask(
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor,
) -> Tensor:
return token_q.new_ones(size=(), dtype=torch.bool, device=batch.device)
"""Returns a noop mask_mod"""
return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
_DEFAULT_SPARSE_BLOCK_SIZE = 128
@@ -267,7 +277,7 @@ class BlockMask:
BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
self.BLOCK_SIZE = BLOCK_SIZE
if mask_mod is None:
mask_mod = _no_mask
mask_mod = noop_mask
self.mask_mod = mask_mod
def as_tuple(self):
@@ -480,6 +490,34 @@ def _convert_mask_to_block_mask(
return partial_blocks, None
def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
    """Returns a mask_mod that's the union of provided mask_mods"""
    if not all(callable(arg) for arg in mask_mods):
        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")

    def or_mask(b, h, q_idx, kv_idx):
        result = b.new_zeros((), dtype=torch.bool)
        for mask in mask_mods:
            result = result | mask(b, h, q_idx, kv_idx)
        return result

    return or_mask


def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
    """Returns a mask_mod that's the intersection of provided mask_mods"""
    if not all(callable(arg) for arg in mask_mods):
        raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")

    def and_mask(b, h, q_idx, kv_idx):
        result = b.new_ones((), dtype=torch.bool)
        for mask in mask_mods:
            result = result & mask(b, h, q_idx, kv_idx)
        return result

    return and_mask


def _convert_block_mask_to_mask(
    block_mask,
    KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
@@ -644,7 +682,7 @@ def create_block_mask(
mod_type = _get_mod_type(mask_mod)
assert (
mod_type == _ModificationType.MASK_MOD
), "create-block_mask requires a mask_mod function!"
), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
inner_func = _create_block_mask_inner
Q_LEN = Q_LEN if Q_LEN < 128 else round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE)
KV_LEN = round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE)