Final minor refinement

This commit is contained in:
jianan-gu 2025-02-07 22:00:07 -05:00
parent 324b268fec
commit 748a571e68

View file

@ -1374,8 +1374,7 @@ class TestFlexAttention(InductorTestCase):
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_doc_mask_sparse(self, device, dtype):
def test_doc_mask_sparse(self, device):
document_id = torch.zeros(S, dtype=torch.int, device=device)
for i in range(0, S, 256):
document_id[i : i + 256] = i // 256
@ -1385,44 +1384,41 @@ class TestFlexAttention(InductorTestCase):
document_mask = document_id[q_idx] == document_id[kv_idx]
return torch.where(causal_mask & document_mask, score, -float("inf"))
self.run_test(document_masking_causal, dtype=dtype, device=device)
self.run_test(document_masking_causal, torch.float16, device=device)
self.run_test_with_paged_attention(
document_masking_causal, dtype=dtype, device=device
document_masking_causal, torch.float16, device=device
)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_index_multiple(self, device, dtype):
def test_index_multiple(self, device):
bias = torch.randn(B, S, device=device)
def index_multiple(score, b, h, q_idx, kv_idx):
return score + bias[b][q_idx]
self.run_test(index_multiple, dtype=dtype, device=device)
self.run_test_with_paged_attention(index_multiple, dtype=dtype, device=device)
self.run_test(index_multiple, torch.float16, device=device)
self.run_test_with_paged_attention(index_multiple, torch.float16, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_index_weird1(self, device, dtype):
def test_index_weird1(self, device):
bias = torch.randn(4, B, H, S, device=device)
def index_weird1(score, b, h, q_idx, kv_idx):
return score + bias[0][b, h][q_idx]
self.run_test(index_weird1, dtype=dtype, device=device)
self.run_test_with_paged_attention(index_weird1, dtype=dtype, device=device)
self.run_test(index_weird1, torch.float16, device=device)
self.run_test_with_paged_attention(index_weird1, torch.float16, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_index_weird2(self, device, dtype):
def test_index_weird2(self, device):
bias = torch.randn(B, H, 4, S, device=device)
which_bias = torch.tensor(0, device=device)
def index_weird2(score, b, h, q_idx, kv_idx):
return score + bias[b][h][which_bias, q_idx]
self.run_test(index_weird2, dtype=dtype, device=device)
self.run_test_with_paged_attention(index_weird2, dtype=dtype, device=device)
self.run_test(index_weird2, torch.float16, device=device)
self.run_test_with_paged_attention(index_weird2, torch.float16, device=device)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)