mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
final minor refinement
This commit is contained in:
parent
324b268fec
commit
748a571e68
1 changed files with 12 additions and 16 deletions
|
|
@ -1374,8 +1374,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("dtype", test_dtypes_fast)
|
||||
def test_doc_mask_sparse(self, device, dtype):
|
||||
def test_doc_mask_sparse(self, device):
|
||||
document_id = torch.zeros(S, dtype=torch.int, device=device)
|
||||
for i in range(0, S, 256):
|
||||
document_id[i : i + 256] = i // 256
|
||||
|
|
@ -1385,44 +1384,41 @@ class TestFlexAttention(InductorTestCase):
|
|||
document_mask = document_id[q_idx] == document_id[kv_idx]
|
||||
return torch.where(causal_mask & document_mask, score, -float("inf"))
|
||||
|
||||
self.run_test(document_masking_causal, dtype=dtype, device=device)
|
||||
self.run_test(document_masking_causal, torch.float16, device=device)
|
||||
self.run_test_with_paged_attention(
|
||||
document_masking_causal, dtype=dtype, device=device
|
||||
document_masking_causal, torch.float16, device=device
|
||||
)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("dtype", test_dtypes_fast)
|
||||
def test_index_multiple(self, device, dtype):
|
||||
def test_index_multiple(self, device):
|
||||
bias = torch.randn(B, S, device=device)
|
||||
|
||||
def index_multiple(score, b, h, q_idx, kv_idx):
|
||||
return score + bias[b][q_idx]
|
||||
|
||||
self.run_test(index_multiple, dtype=dtype, device=device)
|
||||
self.run_test_with_paged_attention(index_multiple, dtype=dtype, device=device)
|
||||
self.run_test(index_multiple, torch.float16, device=device)
|
||||
self.run_test_with_paged_attention(index_multiple, torch.float16, device=device)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("dtype", test_dtypes_fast)
|
||||
def test_index_weird1(self, device, dtype):
|
||||
def test_index_weird1(self, device):
|
||||
bias = torch.randn(4, B, H, S, device=device)
|
||||
|
||||
def index_weird1(score, b, h, q_idx, kv_idx):
|
||||
return score + bias[0][b, h][q_idx]
|
||||
|
||||
self.run_test(index_weird1, dtype=dtype, device=device)
|
||||
self.run_test_with_paged_attention(index_weird1, dtype=dtype, device=device)
|
||||
self.run_test(index_weird1, torch.float16, device=device)
|
||||
self.run_test_with_paged_attention(index_weird1, torch.float16, device=device)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("dtype", test_dtypes_fast)
|
||||
def test_index_weird2(self, device, dtype):
|
||||
def test_index_weird2(self, device):
|
||||
bias = torch.randn(B, H, 4, S, device=device)
|
||||
which_bias = torch.tensor(0, device=device)
|
||||
|
||||
def index_weird2(score, b, h, q_idx, kv_idx):
|
||||
return score + bias[b][h][which_bias, q_idx]
|
||||
|
||||
self.run_test(index_weird2, dtype=dtype, device=device)
|
||||
self.run_test_with_paged_attention(index_weird2, dtype=dtype, device=device)
|
||||
self.run_test(index_weird2, torch.float16, device=device)
|
||||
self.run_test_with_paged_attention(index_weird2, torch.float16, device=device)
|
||||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("dtype", test_dtypes)
|
||||
|
|
|
|||
Loading…
Reference in a new issue