final format fix

This commit is contained in:
jianan-gu 2025-02-07 22:11:39 -05:00
parent 748a571e68
commit 93da119e3d

View file

@ -10,7 +10,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional, Union
from unittest import expectedFailure, skipUnless
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch
import torch
@ -287,6 +287,20 @@ test_block_size = [
(256, 128),
]
# Parametrization cases for stride-related tests.
# Each entry is a 2-tuple: a 4-element tuple of integer expressions built from
# the module-level B, H, S, D constants, followed by a single integer.
# NOTE(review): presumably (strides, storage offset) for building non-contiguous
# inputs, as the "# offset" label on the first case suggests; confirm against
# the test that consumes this list.
# The trailing comment on each case names the layout it is meant to exercise.
test_strides = [
((H * S * D, S * D, D, 1), 997), # offset
((H * D, D, B * H * D, 1), 499), # transposed dimensions
((H * S * D, D, H * D, 1), 0), # heads/sequence transposed
(
(S * (D + 1), B * S * (D + 1), (D + 1), 1),
293,
), # additional buffer on one dim
(
(1, D, (B + 1) * (H + 1) * D, 1),
97,
), # additional buffer on multiple dim + shared dimension
]
def query_key_value_clones(
query: torch.Tensor,
@ -312,21 +326,6 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
)
# Parametrization cases for stride-related tests.
# Each entry is a 2-tuple: a 4-element tuple of integer expressions built from
# the module-level B, H, S, D constants, followed by a single integer.
# NOTE(review): presumably (strides, storage offset) for building non-contiguous
# inputs, as the "# offset" label on the first case suggests; confirm against
# the test that consumes this list.
# The trailing comment on each case names the layout it is meant to exercise.
test_strides = [
((H * S * D, S * D, D, 1), 997), # offset
((H * D, D, B * H * D, 1), 499), # transposed dimensions
((H * S * D, D, H * D, 1), 0), # heads/sequence transposed
(
(S * (D + 1), B * S * (D + 1), (D + 1), 1),
293,
), # additional buffer on one dim
(
(1, D, (B + 1) * (H + 1) * D, 1),
97,
), # additional buffer on multiple dim + shared dimension
]
class TestFlexAttention(InductorTestCase):
def setUp(self):
super(self.__class__, self).setUp()
@ -2138,12 +2137,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.run_test_with_paged_attention(score_mod, device=device)
@supported_platform
@unittest.skipIf(TEST_ON_CUDA, "TODO: Figure out why this is erroring")
@skip("TODO: Figure out why this is erroring")
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune_with_captured(self, device):
head_scale = torch.randn(H, device=device)
batch_scale = torch.randn(B, device=device)
tok_scale = torch.randn(S, device=device)
def test_max_autotune_with_captured(self):
head_scale = torch.randn(H, device="cuda")
batch_scale = torch.randn(B, device="cuda")
tok_scale = torch.randn(S, device="cuda")
def bias_mod(score, batch, head, token_q, token_kv):
score = score + tok_scale[token_q]
@ -2151,7 +2150,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
score = score + head_scale[head]
return score
self.run_test(bias_mod, device=device)
self.run_test(bias_mod)
@supported_platform
@common_utils.parametrize("score_mod", test_score_mods)