mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
[CUDA] update test_flash_attn_cuda.py for Windows (#21006)
Currently test_flash_attn_cuda.py can only run in Linux. It is because it uses triton for rotary reference implementation, and triton python package is not available in Windows. This changes the script to allow the test run in Windows, so that we can test memory efficient attention in Windows. Due to limitation, rotary is excluded in testing on Windows.
This commit is contained in:
parent
f35dd1407f
commit
7c3a25225f
1 changed files with 58 additions and 33 deletions
|
|
@ -21,7 +21,6 @@ from bert_padding import pad_input, unpad_input
|
|||
from einops import rearrange, repeat
|
||||
from onnx import TensorProto, helper
|
||||
from parameterized import parameterized
|
||||
from rotary_flash import apply_rotary_emb
|
||||
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
|
||||
|
|
@ -48,14 +47,16 @@ class Config:
|
|||
kv_num_heads = 0
|
||||
head_size = 0
|
||||
|
||||
def __init__(self, b, s, s2, sp, n, n2, h):
|
||||
self.batch_size = b
|
||||
self.sequence_length = s
|
||||
self.kv_sequence_length = s2
|
||||
self.past_sequence_length = sp
|
||||
self.num_heads = n
|
||||
self.kv_num_heads = n2
|
||||
self.head_size = h
|
||||
def __init__(
|
||||
self, batch_size, sequence_length, kv_sequence_length, past_sequence_length, num_heads, kv_num_heads, head_size
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.kv_sequence_length = kv_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -74,14 +75,23 @@ class PromptConfig:
|
|||
kv_num_heads = 0
|
||||
head_size = 0
|
||||
|
||||
def __init__(self, b, sq, skv, sb, n, n2, h):
|
||||
self.batch_size = b
|
||||
self.q_sequence_length = sq
|
||||
self.kv_sequence_length = skv
|
||||
self.buffer_sequence_length = sb
|
||||
self.num_heads = n
|
||||
self.kv_num_heads = n2
|
||||
self.head_size = h
|
||||
def __init__(
|
||||
self,
|
||||
batch_size,
|
||||
q_sequence_length,
|
||||
kv_sequence_length,
|
||||
buffer_sequence_length,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
head_size,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.q_sequence_length = q_sequence_length
|
||||
self.kv_sequence_length = kv_sequence_length
|
||||
self.buffer_sequence_length = buffer_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -752,6 +762,12 @@ def mha_func(q, k, v, config):
|
|||
return output
|
||||
|
||||
|
||||
def rotary_options_for_current_os():
|
||||
# Reference implementation of rotary uses triton, which is not available in Windows.
|
||||
# So we only test rotary in Linux right now.
|
||||
return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)]
|
||||
|
||||
|
||||
def gqa_prompt_func(
|
||||
q,
|
||||
k,
|
||||
|
|
@ -1176,6 +1192,13 @@ def parity_check_mha(
|
|||
return all_close
|
||||
|
||||
|
||||
def rotary_embedding(*args, **kwargs):
|
||||
# Use local import since triton is not available in Windows.
|
||||
from rotary_flash import apply_rotary_emb
|
||||
|
||||
return apply_rotary_emb(*args, **kwargs)
|
||||
|
||||
|
||||
def parity_check_gqa_prompt(
|
||||
config,
|
||||
causal=True,
|
||||
|
|
@ -1265,11 +1288,12 @@ def parity_check_gqa_prompt(
|
|||
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=torch.float16)
|
||||
sin = torch.sin(angle).to(dtype=torch.float16)
|
||||
|
||||
if causal or local:
|
||||
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
q_ro = rearrange(
|
||||
apply_rotary_emb(
|
||||
rotary_embedding(
|
||||
rearrange(q, "b s h d -> b 1 (s h) d"),
|
||||
cos,
|
||||
sin,
|
||||
|
|
@ -1280,7 +1304,7 @@ def parity_check_gqa_prompt(
|
|||
s=config.q_sequence_length,
|
||||
)
|
||||
# q_ro = q
|
||||
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
cos, sin = None, None
|
||||
q_ro, k_ro = q, new_k
|
||||
|
|
@ -1469,11 +1493,12 @@ def parity_check_gqa_prompt_no_buff(
|
|||
angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=torch.float16)
|
||||
sin = torch.sin(angle).to(dtype=torch.float16)
|
||||
|
||||
if causal or local:
|
||||
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
q_ro = rearrange(
|
||||
apply_rotary_emb(
|
||||
rotary_embedding(
|
||||
rearrange(q, "b s h d -> b 1 (s h) d"),
|
||||
cos,
|
||||
sin,
|
||||
|
|
@ -1484,7 +1509,7 @@ def parity_check_gqa_prompt_no_buff(
|
|||
s=config.q_sequence_length,
|
||||
)
|
||||
# q_ro = q
|
||||
k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
cos, sin = None, None
|
||||
q_ro, k_ro = q, k_cache_ref
|
||||
|
|
@ -1669,10 +1694,10 @@ def parity_check_gqa_past(
|
|||
cos = torch.cos(angle).to(dtype=torch.float16)
|
||||
sin = torch.sin(angle).to(dtype=torch.float16)
|
||||
if causal or local:
|
||||
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
q_ro = rearrange(
|
||||
apply_rotary_emb(
|
||||
rotary_embedding(
|
||||
rearrange(q, "b s h d -> b 1 (s h) d"),
|
||||
cos,
|
||||
sin,
|
||||
|
|
@ -1683,7 +1708,7 @@ def parity_check_gqa_past(
|
|||
s=config.sequence_length,
|
||||
)
|
||||
# q_ro = q
|
||||
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
cos, sin = None, None
|
||||
q_ro, k_ro = q, new_k
|
||||
|
|
@ -1878,10 +1903,10 @@ def parity_check_gqa_past_no_buff(
|
|||
cos = torch.cos(angle).to(dtype=torch.float16)
|
||||
sin = torch.sin(angle).to(dtype=torch.float16)
|
||||
if causal or local:
|
||||
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
q_ro = rearrange(
|
||||
apply_rotary_emb(
|
||||
rotary_embedding(
|
||||
rearrange(q, "b s h d -> b 1 (s h) d"),
|
||||
cos,
|
||||
sin,
|
||||
|
|
@ -1892,7 +1917,7 @@ def parity_check_gqa_past_no_buff(
|
|||
s=config.sequence_length,
|
||||
)
|
||||
# q_ro = q
|
||||
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
|
||||
else:
|
||||
cos, sin = None, None
|
||||
q_ro, k_ro = q, new_k
|
||||
|
|
@ -2081,7 +2106,7 @@ def gqa_no_past_memory_efficient_test_cases():
|
|||
for sq, skv in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for rotary, rotary_interleaved in rotary_options_for_current_os():
|
||||
for packed in [False, True]:
|
||||
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
|
||||
yield (
|
||||
|
|
@ -2121,7 +2146,7 @@ def gqa_no_past_flash_attention_test_cases():
|
|||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for rotary, rotary_interleaved in rotary_options_for_current_os():
|
||||
for packed in [False, True]:
|
||||
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
|
||||
yield (
|
||||
|
|
@ -2161,7 +2186,7 @@ def gqa_past_memory_efficient_test_cases():
|
|||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for rotary, rotary_interleaved in rotary_options_for_current_os():
|
||||
for packed in [False, True]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
|
|
@ -2202,7 +2227,7 @@ def gqa_past_flash_attention_test_cases():
|
|||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for rotary, rotary_interleaved in rotary_options_for_current_os():
|
||||
for packed in [False, True]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
|
|
|
|||
Loading…
Reference in a new issue