Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
final format fix
parent 748a571e68
commit 93da119e3d

1 changed file with 21 additions and 22 deletions
@@ -10,7 +10,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import product
 from typing import Callable, Optional, Union
-from unittest import expectedFailure, skipUnless
+from unittest import expectedFailure, skip, skipUnless
 from unittest.mock import patch
 
 import torch
@@ -287,6 +287,20 @@ test_block_size = [
     (256, 128),
 ]
 
+test_strides = [
+    ((H * S * D, S * D, D, 1), 997),  # offset
+    ((H * D, D, B * H * D, 1), 499),  # transposed dimensions
+    ((H * S * D, D, H * D, 1), 0),  # heads/sequence transposed
+    (
+        (S * (D + 1), B * S * (D + 1), (D + 1), 1),
+        293,
+    ),  # additional buffer on one dim
+    (
+        (1, D, (B + 1) * (H + 1) * D, 1),
+        97,
+    ),  # additional buffer on multiple dim + shared dimension
+]
+
 
 def query_key_value_clones(
     query: torch.Tensor,
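Reviewer note: each `test_strides` entry pairs a stride tuple with a storage offset, so the tests exercise non-contiguous input layouts. A minimal sketch of how such a pair can be materialized via `torch.Tensor.as_strided`; the constants and the `strided_tensor` helper below are hypothetical stand-ins, not part of this commit:

import torch

# Hypothetical sizes; the real test file defines B, H, S, D as module constants.
B, H, S, D = 4, 8, 128, 64

def strided_tensor(strides, offset, shape=(B, H, S, D)):
    # Size the backing buffer so the furthest element reachable via
    # (shape, strides, offset) stays in bounds, then take a strided view.
    span = offset + sum((n - 1) * st for n, st in zip(shape, strides)) + 1
    buf = torch.randn(span)
    return buf.as_strided(shape, strides, storage_offset=offset)

# The "transposed dimensions" entry above yields a non-contiguous layout:
k = strided_tensor((H * D, D, B * H * D, 1), 499)
print(k.shape, k.is_contiguous())  # torch.Size([4, 8, 128, 64]) False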
@@ -312,21 +326,6 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
     )
 
 
-test_strides = [
-    ((H * S * D, S * D, D, 1), 997),  # offset
-    ((H * D, D, B * H * D, 1), 499),  # transposed dimensions
-    ((H * S * D, D, H * D, 1), 0),  # heads/sequence transposed
-    (
-        (S * (D + 1), B * S * (D + 1), (D + 1), 1),
-        293,
-    ),  # additional buffer on one dim
-    (
-        (1, D, (B + 1) * (H + 1) * D, 1),
-        97,
-    ),  # additional buffer on multiple dim + shared dimension
-]
-
-
 class TestFlexAttention(InductorTestCase):
     def setUp(self):
         super(self.__class__, self).setUp()
@@ -2138,12 +2137,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test_with_paged_attention(score_mod, device=device)
 
     @supported_platform
-    @unittest.skipIf(TEST_ON_CUDA, "TODO: Figure out why this is erroring")
+    @skip("TODO: Figure out why this is erroring")
     @patch.object(torch._inductor.config, "max_autotune", True)
-    def test_max_autotune_with_captured(self, device):
-        head_scale = torch.randn(H, device=device)
-        batch_scale = torch.randn(B, device=device)
-        tok_scale = torch.randn(S, device=device)
+    def test_max_autotune_with_captured(self):
+        head_scale = torch.randn(H, device="cuda")
+        batch_scale = torch.randn(B, device="cuda")
+        tok_scale = torch.randn(S, device="cuda")
 
         def bias_mod(score, batch, head, token_q, token_kv):
             score = score + tok_scale[token_q]
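Reviewer note: the hunk above swaps the conditional `@unittest.skipIf(TEST_ON_CUDA, ...)` for an unconditional `@skip(...)`, which matches the `skip` added to the `unittest` import in the first hunk. A small sketch of the behavioral difference, with `SOME_CONDITION` as a hypothetical stand-in for `TEST_ON_CUDA`:

from unittest import TestCase, main, skip, skipIf

SOME_CONDITION = True  # hypothetical stand-in for TEST_ON_CUDA

class SkipDemo(TestCase):
    @skip("always skipped, on every platform")
    def test_unconditional(self):
        self.fail("never runs")

    @skipIf(SOME_CONDITION, "skipped only when the condition is true")
    def test_conditional(self):
        self.fail("runs whenever SOME_CONDITION is false")

if __name__ == "__main__":
    main()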
@@ -2151,7 +2150,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             score = score + head_scale[head]
             return score
 
-        self.run_test(bias_mod, device=device)
+        self.run_test(bias_mod)
 
     @supported_platform
     @common_utils.parametrize("score_mod", test_score_mods)
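Reviewer note: `bias_mod` in the hunks above is a flex-attention score_mod, i.e. a callable `(score, batch, head, q_idx, kv_idx) -> score` that rewrites each attention score. A minimal standalone sketch of how such a mod is consumed, assuming the public `torch.nn.attention.flex_attention` API; the sizes below are hypothetical:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 64, 32  # hypothetical sizes
tok_scale = torch.randn(S)

def bias_mod(score, batch, head, token_q, token_kv):
    # Add a per-query-token bias to every attention score.
    return score + tok_scale[token_q]

q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
out = flex_attention(q, k, v, score_mod=bias_mod)
print(out.shape)  # torch.Size([2, 4, 64, 32])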