diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 3804dbb52a..7dbbb9249c 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -21,7 +21,6 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper from parameterized import parameterized -from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -48,14 +47,16 @@ class Config: kv_num_heads = 0 head_size = 0 - def __init__(self, b, s, s2, sp, n, n2, h): - self.batch_size = b - self.sequence_length = s - self.kv_sequence_length = s2 - self.past_sequence_length = sp - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + def __init__( + self, batch_size, sequence_length, kv_sequence_length, past_sequence_length, num_heads, kv_num_heads, head_size + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.kv_sequence_length = kv_sequence_length + self.past_sequence_length = past_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size def __repr__(self): return ( @@ -74,14 +75,23 @@ class PromptConfig: kv_num_heads = 0 head_size = 0 - def __init__(self, b, sq, skv, sb, n, n2, h): - self.batch_size = b - self.q_sequence_length = sq - self.kv_sequence_length = skv - self.buffer_sequence_length = sb - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + def __init__( + self, + batch_size, + q_sequence_length, + kv_sequence_length, + buffer_sequence_length, + num_heads, + kv_num_heads, + head_size, + ): + self.batch_size = batch_size + self.q_sequence_length = q_sequence_length + self.kv_sequence_length = kv_sequence_length + self.buffer_sequence_length = buffer_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size def __repr__(self): return ( @@ -752,6 +762,12 @@ def mha_func(q, k, v, config): return output +def rotary_options_for_current_os(): + # Reference implementation of rotary uses triton, which is not available in Windows. + # So we only test rotary in Linux right now. + return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)] + + def gqa_prompt_func( q, k, @@ -1176,6 +1192,13 @@ def parity_check_mha( return all_close +def rotary_embedding(*args, **kwargs): + # Use local import since triton is not available in Windows. + from rotary_flash import apply_rotary_emb + + return apply_rotary_emb(*args, **kwargs) + + def parity_check_gqa_prompt( config, causal=True, @@ -1265,11 +1288,12 @@ def parity_check_gqa_prompt( angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi cos = torch.cos(angle).to(dtype=torch.float16) sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: - q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) else: q_ro = rearrange( - apply_rotary_emb( + rotary_embedding( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, @@ -1280,7 +1304,7 @@ def parity_check_gqa_prompt( s=config.q_sequence_length, ) # q_ro = q - k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None q_ro, k_ro = q, new_k @@ -1469,11 +1493,12 @@ def parity_check_gqa_prompt_no_buff( angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi cos = torch.cos(angle).to(dtype=torch.float16) sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: - q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) else: q_ro = rearrange( - apply_rotary_emb( + rotary_embedding( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, @@ -1484,7 +1509,7 @@ def parity_check_gqa_prompt_no_buff( s=config.q_sequence_length, ) # q_ro = q - k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None q_ro, k_ro = q, k_cache_ref @@ -1669,10 +1694,10 @@ def parity_check_gqa_past( cos = torch.cos(angle).to(dtype=torch.float16) sin = torch.sin(angle).to(dtype=torch.float16) if causal or local: - q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: q_ro = rearrange( - apply_rotary_emb( + rotary_embedding( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, @@ -1683,7 +1708,7 @@ def parity_check_gqa_past( s=config.sequence_length, ) # q_ro = q - k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None q_ro, k_ro = q, new_k @@ -1878,10 +1903,10 @@ def parity_check_gqa_past_no_buff( cos = torch.cos(angle).to(dtype=torch.float16) sin = torch.sin(angle).to(dtype=torch.float16) if causal or local: - q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: q_ro = rearrange( - apply_rotary_emb( + rotary_embedding( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, @@ -1892,7 +1917,7 @@ def parity_check_gqa_past_no_buff( s=config.sequence_length, ) # q_ro = q - k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None q_ro, k_ro = q, new_k @@ -2081,7 +2106,7 @@ def gqa_no_past_memory_efficient_test_cases(): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) yield ( @@ -2121,7 +2146,7 @@ def gqa_no_past_flash_attention_test_cases(): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) yield ( @@ -2161,7 +2186,7 @@ def gqa_past_memory_efficient_test_cases(): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) @@ -2202,7 +2227,7 @@ def gqa_past_flash_attention_test_cases(): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h)