[CUDA] update test_flash_attn_cuda.py for Windows (#21006)

Currently test_flash_attn_cuda.py can only run on Linux, because it uses
triton for the rotary reference implementation and the triton Python
package is not available on Windows.

This change updates the script so that the test can also run on Windows,
which lets us test memory-efficient attention on Windows.

Because of the triton limitation, rotary is excluded from testing on Windows.
This commit is contained in:
Tianlei Wu 2024-06-13 12:50:02 -07:00 committed by GitHub
parent f35dd1407f
commit 7c3a25225f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -21,7 +21,6 @@ from bert_padding import pad_input, unpad_input
from einops import rearrange, repeat
from onnx import TensorProto, helper
from parameterized import parameterized
from rotary_flash import apply_rotary_emb
from onnxruntime import InferenceSession, OrtValue, SessionOptions
@ -48,14 +47,16 @@ class Config:
kv_num_heads = 0
head_size = 0
def __init__(self, b, s, s2, sp, n, n2, h):
# Shorthand-argument form of the constructor: each positional value is stored
# unchanged on the instance under the descriptive attribute name shown on the
# left-hand side below (b -> batch_size, s -> sequence_length,
# s2 -> kv_sequence_length, sp -> past_sequence_length, n -> num_heads,
# n2 -> kv_num_heads, h -> head_size). No validation or conversion is done.
self.batch_size = b
self.sequence_length = s
self.kv_sequence_length = s2
self.past_sequence_length = sp
self.num_heads = n
self.kv_num_heads = n2
self.head_size = h
def __init__(
self, batch_size, sequence_length, kv_sequence_length, past_sequence_length, num_heads, kv_num_heads, head_size
):
# Descriptive-name form of the constructor: every argument is stored unchanged
# on the instance under an attribute of the same name. The positional order is
# batch_size, sequence_length, kv_sequence_length, past_sequence_length,
# num_heads, kv_num_heads, head_size. No validation or conversion is done.
self.batch_size = batch_size
self.sequence_length = sequence_length
self.kv_sequence_length = kv_sequence_length
self.past_sequence_length = past_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
def __repr__(self):
return (
@ -74,14 +75,23 @@ class PromptConfig:
kv_num_heads = 0
head_size = 0
def __init__(self, b, sq, skv, sb, n, n2, h):
# Shorthand-argument form of the constructor: each positional value is stored
# unchanged on the instance under the descriptive attribute name shown on the
# left-hand side below (b -> batch_size, sq -> q_sequence_length,
# skv -> kv_sequence_length, sb -> buffer_sequence_length, n -> num_heads,
# n2 -> kv_num_heads, h -> head_size). No validation or conversion is done.
self.batch_size = b
self.q_sequence_length = sq
self.kv_sequence_length = skv
self.buffer_sequence_length = sb
self.num_heads = n
self.kv_num_heads = n2
self.head_size = h
def __init__(
self,
batch_size,
q_sequence_length,
kv_sequence_length,
buffer_sequence_length,
num_heads,
kv_num_heads,
head_size,
):
# Descriptive-name form of the constructor: every argument is stored unchanged
# on the instance under an attribute of the same name. The positional order is
# batch_size, q_sequence_length, kv_sequence_length, buffer_sequence_length,
# num_heads, kv_num_heads, head_size. No validation or conversion is done.
self.batch_size = batch_size
self.q_sequence_length = q_sequence_length
self.kv_sequence_length = kv_sequence_length
self.buffer_sequence_length = buffer_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
def __repr__(self):
return (
@ -752,6 +762,12 @@ def mha_func(q, k, v, config):
return output
def rotary_options_for_current_os():
    """Return the (rotary, rotary_interleaved) combinations to test on this OS.

    The reference implementation of rotary uses triton, which is not available
    on Windows, so the rotary combinations are only produced on Linux. Every
    other platform gets the single non-rotary combination.
    """
    if platform.system() == "Linux":
        return [(True, False), (True, True), (False, False)]
    return [(False, False)]
def gqa_prompt_func(
q,
k,
@ -1176,6 +1192,13 @@ def parity_check_mha(
return all_close
def rotary_embedding(*args, **kwargs):
    """Forward all arguments to ``rotary_flash.apply_rotary_emb`` and return its result.

    rotary_flash is imported at call time rather than at module level because
    triton is not available on Windows.
    """
    import rotary_flash

    return rotary_flash.apply_rotary_emb(*args, **kwargs)
def parity_check_gqa_prompt(
config,
causal=True,
@ -1265,11 +1288,12 @@ def parity_check_gqa_prompt(
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
apply_rotary_emb(
rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@ -1280,7 +1304,7 @@ def parity_check_gqa_prompt(
s=config.q_sequence_length,
)
# q_ro = q
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
@ -1469,11 +1493,12 @@ def parity_check_gqa_prompt_no_buff(
angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
apply_rotary_emb(
rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@ -1484,7 +1509,7 @@ def parity_check_gqa_prompt_no_buff(
s=config.q_sequence_length,
)
# q_ro = q
k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, k_cache_ref
@ -1669,10 +1694,10 @@ def parity_check_gqa_past(
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
apply_rotary_emb(
rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@ -1683,7 +1708,7 @@ def parity_check_gqa_past(
s=config.sequence_length,
)
# q_ro = q
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
@ -1878,10 +1903,10 @@ def parity_check_gqa_past_no_buff(
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
apply_rotary_emb(
rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@ -1892,7 +1917,7 @@ def parity_check_gqa_past_no_buff(
s=config.sequence_length,
)
# q_ro = q
k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
@ -2081,7 +2106,7 @@ def gqa_no_past_memory_efficient_test_cases():
for sq, skv in seqs:
for n, n2 in num_h:
for h in h_sizes:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
yield (
@ -2121,7 +2146,7 @@ def gqa_no_past_flash_attention_test_cases():
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
yield (
@ -2161,7 +2186,7 @@ def gqa_past_memory_efficient_test_cases():
for s, s2 in seqs:
for n, n2 in num_h:
for h in h_sizes:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
@ -2202,7 +2227,7 @@ def gqa_past_flash_attention_test_cases():
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)