Fix batch-specific attention mod for NJT + Flex (#143866)

Fixes #143788
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143866
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
Joel Schlosser 2024-12-26 12:31:58 -05:00 committed by PyTorch MergeBot
parent 1e65dec2b9
commit 228b228449
2 changed files with 37 additions and 3 deletions
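Context for the fix: with nested jagged tensor (NJT) inputs, the adapter that wraps a user-supplied score_mod/mask_mod forwarded the raw batch index, which is always 0 for the packed NJT layout, so any modification keyed on the batch index was silently dropped (see issue #143788 above). Below is a minimal sketch of the kind of usage this commit is meant to make correct; the shapes, sequence lengths, and helper names are illustrative rather than taken from the PR, and a device supported by flex attention (e.g. CUDA) is assumed.

import torch
from torch.nn.attention.flex_attention import flex_attention

device, dtype = "cuda", torch.float32
n_heads, head_dim = 2, 16
seq_lens = [3, 5, 4]  # ragged sequence lengths, one per batch element

def make_njt():
    # Components are (seq_len, n_heads, head_dim); transpose to the (B, H, S, D)
    # layout flex_attention expects, with S as the ragged dimension.
    nt = torch.nested.nested_tensor(
        [torch.randn(s, n_heads, head_dim, device=device, dtype=dtype) for s in seq_lens],
        layout=torch.jagged,
    )
    return nt.transpose(1, 2)

query, key, value = make_njt(), make_njt(), make_njt()

# Additive bias that differs per batch element; before this fix, the b passed to
# the score_mod was always 0 for NJT inputs, so the bias had no effect.
batch_bias = torch.randn(len(seq_lens), device=device, dtype=dtype)

def batch_bias_score_mod(score, b, h, q_idx, kv_idx):
    return score + batch_bias[b]

out = flex_attention(query, key, value, score_mod=batch_bias_score_mod)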


@@ -7240,14 +7240,46 @@ torch.cuda.synchronize()
         flex_attention(query, key, value, score_mod=my_score_mod)
+        # Test with batch-specific score_mod
+        batch_size = query.size(0)
+        batch_table = torch.randn(batch_size, device=device, dtype=dtype)
+        # Keep score the same for batch index == 0
+        batch_table[0].zero_()
+        def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
+            return score + batch_table[b]
+        def identity_score_mod(score, b, h, q_idx, kv_idx):
+            return score
+        output = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
+        output_identity = flex_attention(
+            query, key, value, score_mod=identity_score_mod
+        )
+        # Guard against a bug where the batch index passed to score_mod is always b == 0.
+        # Output would be equivalent to applying an identity score_mod.
+        # See https://github.com/pytorch/pytorch/issues/143788
+        self.assertFalse(torch.allclose(output._values, output_identity._values))
         # Test with mask_mod
         mask_mod_table = score_mod_table > 0.0
         def my_mask_mod(b, h, q_idx, kv_idx):
             return mask_mod_table[q_idx]
+        def my_mask_mod2(b, h, q_idx, kv_idx):
+            return mask_mod_table[q_idx] & (b == 0)
         block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
-        flex_attention(query, key, value, block_mask=block_mask)
+        output = flex_attention(query, key, value, block_mask=block_mask)
+        block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True)
+        output2 = flex_attention(query, key, value, block_mask=block_mask2)
+        # Guard against a bug where the batch index passed to mask_mod is always b == 0.
+        # See https://github.com/pytorch/pytorch/issues/143788
+        self.assertFalse(torch.allclose(output._values, output2._values))
     @dtypes(torch.float32)
     def test_apply_(self, device, dtype):


@@ -969,12 +969,13 @@ def _nested_mod_func_adapter(
     if is_score_mod:
         def nt_score_mod(score, b, h, q_idx, kv_idx):
+            b_nested = q_seq_idx[q_idx]
             q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
             kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
             is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
             return torch.where(
                 is_same_sequence,
-                orig_mod_func(score, b, h, q_nested, kv_nested),  # type: ignore[call-arg]
+                orig_mod_func(score, b_nested, h, q_nested, kv_nested),  # type: ignore[call-arg]
                 # don't allow inter-sequence attention
                 float("-inf"),
             )
@@ -983,11 +984,12 @@ def _nested_mod_func_adapter(
     else:
         def nt_mask_mod(b, h, q_idx, kv_idx):
+            b_nested = q_seq_idx[q_idx]
             q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
             kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
             # don't allow inter-sequence attention
             is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
-            return orig_mod_func(b, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]
+            return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]
         return nt_mask_mod
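
For reference on how the adapter recovers the per-sequence batch index: flex attention over an NJT operates on the packed (values) view, so q_seq_idx maps a flattened token index to the sequence it came from, and q_offsets holds each sequence's start within the packed dimension. The following is a standalone sketch of that index arithmetic; the offsets and lengths are made up for illustration, not taken from the PR.

import torch

# Packed layout for ragged sequence lengths [3, 5, 4].
q_offsets = torch.tensor([0, 3, 8, 12])                                  # sequence starts, plus total length
q_seq_idx = torch.repeat_interleave(torch.arange(3), q_offsets.diff())  # [0,0,0, 1,1,1,1,1, 2,2,2,2]

q_idx = torch.tensor(9)                 # index of a query token in the packed dimension
b_nested = q_seq_idx[q_idx]             # tensor(2): which sequence/batch element it belongs to
q_nested = q_idx - q_offsets[b_nested]  # tensor(1): its position within that sequence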