From 748a571e689a3437a8abcbfb7e3f95bcd3df573a Mon Sep 17 00:00:00 2001
From: jianan-gu
Date: Fri, 7 Feb 2025 22:00:07 -0500
Subject: [PATCH] final minor refinement

---
 test/inductor/test_flex_attention.py | 28 ++++++++++++----------------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index a4bcb3388fa..4839bd19feb 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1374,8 +1374,7 @@ class TestFlexAttention(InductorTestCase):
         )
 
     @supported_platform
-    @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_doc_mask_sparse(self, device, dtype):
+    def test_doc_mask_sparse(self, device):
         document_id = torch.zeros(S, dtype=torch.int, device=device)
         for i in range(0, S, 256):
             document_id[i : i + 256] = i // 256
@@ -1385,44 +1384,41 @@ class TestFlexAttention(InductorTestCase):
             document_mask = document_id[q_idx] == document_id[kv_idx]
             return torch.where(causal_mask & document_mask, score, -float("inf"))
 
-        self.run_test(document_masking_causal, dtype=dtype, device=device)
+        self.run_test(document_masking_causal, torch.float16, device=device)
         self.run_test_with_paged_attention(
-            document_masking_causal, dtype=dtype, device=device
+            document_masking_causal, torch.float16, device=device
         )
 
     @supported_platform
-    @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_index_multiple(self, device, dtype):
+    def test_index_multiple(self, device):
         bias = torch.randn(B, S, device=device)
 
         def index_multiple(score, b, h, q_idx, kv_idx):
             return score + bias[b][q_idx]
 
-        self.run_test(index_multiple, dtype=dtype, device=device)
-        self.run_test_with_paged_attention(index_multiple, dtype=dtype, device=device)
+        self.run_test(index_multiple, torch.float16, device=device)
+        self.run_test_with_paged_attention(index_multiple, torch.float16, device=device)
 
     @supported_platform
-    @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_index_weird1(self, device, dtype):
+    def test_index_weird1(self, device):
         bias = torch.randn(4, B, H, S, device=device)
 
         def index_weird1(score, b, h, q_idx, kv_idx):
             return score + bias[0][b, h][q_idx]
 
-        self.run_test(index_weird1, dtype=dtype, device=device)
-        self.run_test_with_paged_attention(index_weird1, dtype=dtype, device=device)
+        self.run_test(index_weird1, torch.float16, device=device)
+        self.run_test_with_paged_attention(index_weird1, torch.float16, device=device)
 
     @supported_platform
-    @common_utils.parametrize("dtype", test_dtypes_fast)
-    def test_index_weird2(self, device, dtype):
+    def test_index_weird2(self, device):
         bias = torch.randn(B, H, 4, S, device=device)
         which_bias = torch.tensor(0, device=device)
 
         def index_weird2(score, b, h, q_idx, kv_idx):
             return score + bias[b][h][which_bias, q_idx]
 
-        self.run_test(index_weird2, dtype=dtype, device=device)
-        self.run_test_with_paged_attention(index_weird2, dtype=dtype, device=device)
+        self.run_test(index_weird2, torch.float16, device=device)
+        self.run_test_with_paged_attention(index_weird2, torch.float16, device=device)
 
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)