Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00

Fix batch-specific attention mod for NJT + Flex (#143866)

Fixes #143788. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143866. Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
This commit is contained in:
parent
1e65dec2b9
commit
228b228449
2 changed files with 37 additions and 3 deletions
|
|
@@ -7240,14 +7240,46 @@ torch.cuda.synchronize()
         flex_attention(query, key, value, score_mod=my_score_mod)

+        # Test with batch-specific score_mod
+        batch_size = query.size(0)
+        batch_table = torch.randn(batch_size, device=device, dtype=dtype)
+        # Keep score the same for batch index == 0
+        batch_table[0].zero_()
+
+        def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
+            return score + batch_table[b]
+
+        def identity_score_mod(score, b, h, q_idx, kv_idx):
+            return score
+
+        output = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
+        output_identity = flex_attention(
+            query, key, value, score_mod=identity_score_mod
+        )
+
+        # Guard against a bug where the batch index passed to score_mod is always b == 0.
+        # Output would be equivalent to applying an identity score_mod.
+        # See https://github.com/pytorch/pytorch/issues/143788
+        self.assertFalse(torch.allclose(output._values, output_identity._values))
+
         # Test with mask_mod
         mask_mod_table = score_mod_table > 0.0

         def my_mask_mod(b, h, q_idx, kv_idx):
             return mask_mod_table[q_idx]

+        def my_mask_mod2(b, h, q_idx, kv_idx):
+            return mask_mod_table[q_idx] & (b == 0)
+
         block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
-        flex_attention(query, key, value, block_mask=block_mask)
+        output = flex_attention(query, key, value, block_mask=block_mask)
+
+        block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True)
+        output2 = flex_attention(query, key, value, block_mask=block_mask2)
+
+        # Guard against a bug where the batch index passed to mask_mod is always b == 0.
+        # See https://github.com/pytorch/pytorch/issues/143788
+        self.assertFalse(torch.allclose(output._values, output2._values))

     @dtypes(torch.float32)
     def test_apply_(self, device, dtype):
@@ -969,12 +969,13 @@ def _nested_mod_func_adapter(
     if is_score_mod:

         def nt_score_mod(score, b, h, q_idx, kv_idx):
+            b_nested = q_seq_idx[q_idx]
             q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
             kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
             is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
             return torch.where(
                 is_same_sequence,
-                orig_mod_func(score, b, h, q_nested, kv_nested),  # type: ignore[call-arg]
+                orig_mod_func(score, b_nested, h, q_nested, kv_nested),  # type: ignore[call-arg]
                 # don't allow inter-sequence attention
                 float("-inf"),
             )
@@ -983,11 +984,12 @@ def _nested_mod_func_adapter(
     else:

         def nt_mask_mod(b, h, q_idx, kv_idx):
+            b_nested = q_seq_idx[q_idx]
             q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
             kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
             # don't allow inter-sequence attention
             is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
-            return orig_mod_func(b, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]
+            return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]

         return nt_mask_mod
Loading…
Reference in a new issue