From 93da119e3d2bc7622aa10504bb2f51146294aa9d Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 7 Feb 2025 22:11:39 -0500 Subject: [PATCH] final format fix --- test/inductor/test_flex_attention.py | 43 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 4839bd19feb..ebf5a3ca5dd 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -10,7 +10,7 @@ from contextlib import contextmanager from dataclasses import dataclass from itertools import product from typing import Callable, Optional, Union -from unittest import expectedFailure, skipUnless +from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch import torch @@ -287,6 +287,20 @@ test_block_size = [ (256, 128), ] +test_strides = [ + ((H * S * D, S * D, D, 1), 997), # offset + ((H * D, D, B * H * D, 1), 499), # transposed dimensions + ((H * S * D, D, H * D, 1), 0), # heads/sequence transposed + ( + (S * (D + 1), B * S * (D + 1), (D + 1), 1), + 293, + ), # additional buffer on one dim + ( + (1, D, (B + 1) * (H + 1) * D, 1), + 97, + ), # additional buffer on multiple dim + shared dimension +] + def query_key_value_clones( query: torch.Tensor, @@ -312,21 +326,6 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): ) -test_strides = [ - ((H * S * D, S * D, D, 1), 997), # offset - ((H * D, D, B * H * D, 1), 499), # transposed dimensions - ((H * S * D, D, H * D, 1), 0), # heads/sequence transposed - ( - (S * (D + 1), B * S * (D + 1), (D + 1), 1), - 293, - ), # additional buffer on one dim - ( - (1, D, (B + 1) * (H + 1) * D, 1), - 97, - ), # additional buffer on multiple dim + shared dimension -] - - class TestFlexAttention(InductorTestCase): def setUp(self): super(self.__class__, self).setUp() @@ -2138,12 +2137,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): self.run_test_with_paged_attention(score_mod, device=device) @supported_platform - @unittest.skipIf(TEST_ON_CUDA, "TODO: Figure out why this is erroring") + @skip("TODO: Figure out why this is erroring") @patch.object(torch._inductor.config, "max_autotune", True) - def test_max_autotune_with_captured(self, device): - head_scale = torch.randn(H, device=device) - batch_scale = torch.randn(B, device=device) - tok_scale = torch.randn(S, device=device) + def test_max_autotune_with_captured(self): + head_scale = torch.randn(H, device="cuda") + batch_scale = torch.randn(B, device="cuda") + tok_scale = torch.randn(S, device="cuda") def bias_mod(score, batch, head, token_q, token_kv): score = score + tok_scale[token_q] @@ -2151,7 +2150,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): score = score + head_scale[head] return score - self.run_test(bias_mod, device=device) + self.run_test(bias_mod) @supported_platform @common_utils.parametrize("score_mod", test_score_mods)