test: refactor flash_attn tests to use parameterized (#20913)

Use `parameterized` to decompose the huge test cases into smaller ones.
This will make it possible to add ROCm support.

---------

Co-authored-by: Guangyun Han <guangyunhan@microsoft.com@h100vm-ort.kxelwkzfzxguje5bxvwxxs135a.gvxx.internal.cloudapp.net>
This commit is contained in:
cloudhan 2024-06-12 06:57:20 +08:00 committed by GitHub
parent b3fc9b5a0e
commit 67c8befd1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 329 additions and 249 deletions

View file

@ -20,6 +20,7 @@ import torch
from bert_padding import pad_input, unpad_input
from einops import rearrange, repeat
from onnx import TensorProto, helper
from parameterized import parameterized
from rotary_flash import apply_rotary_emb
from onnxruntime import InferenceSession, OrtValue, SessionOptions
@ -56,6 +57,13 @@ class Config:
self.kv_num_heads = n2
self.head_size = h
def __repr__(self):
    """Return a ``Config(name=value, ...)`` string covering every field.

    The fields appear in a fixed order: batch size, the three sequence
    lengths, then the head counts and head size.
    """
    field_names = (
        "batch_size",
        "sequence_length",
        "kv_sequence_length",
        "past_sequence_length",
        "num_heads",
        "kv_num_heads",
        "head_size",
    )
    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
    return f"Config({rendered})"
class PromptConfig:
batch_size = 0
@ -75,6 +83,13 @@ class PromptConfig:
self.kv_num_heads = n2
self.head_size = h
def __repr__(self):
    """Return a ``PromptConfig(name=value, ...)`` string covering every field.

    The fields appear in a fixed order: batch size, the three sequence
    lengths, then the head counts and head size.
    """
    field_names = (
        "batch_size",
        "q_sequence_length",
        "kv_sequence_length",
        "buffer_sequence_length",
        "num_heads",
        "kv_num_heads",
        "head_size",
    )
    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
    return f"PromptConfig({rendered})"
def create_packed_multihead_attention_graph(config):
nodes = [
@ -1974,293 +1989,357 @@ def parity_check_gqa_past_no_buff(
return all_close
def packed_mha_test_cases():
    """Yield ``(name, config)`` pairs for the packed MHA parity test.

    The name is ``str(config)`` and the config is a ``Config`` built from one
    point of the batch / sequence-length / head-count / head-size grid.
    ``pipeline_mode`` selects the reduced grid; otherwise the full grid is used.
    """
    if pipeline_mode:
        batch_sizes = [2]
        seq_lens = [8, 97, 256, 1024]
        head_counts = [1, 3]
        head_dims = [16, 256]
    else:
        batch_sizes = [1, 5]
        seq_lens = [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
        head_counts = [1, 6, 16]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    # Same iteration order as nested loops: batch outermost, head size innermost.
    grid = (
        (batch, seq_len, heads, head_dim)
        for batch in batch_sizes
        for seq_len in seq_lens
        for heads in head_counts
        for head_dim in head_dims
    )
    for batch, seq_len, heads, head_dim in grid:
        # The same length is passed for both sequence arguments and the same
        # count for both head arguments; the fourth argument is always 0.
        case = Config(batch, seq_len, seq_len, 0, heads, heads, head_dim)
        yield str(case), case
def mha_test_cases():
    """Yield ``(name, config)`` pairs for the (non-packed) MHA parity test.

    The name is ``str(config)`` and the config is a ``Config`` built from one
    point of the batch / sequence-length-pair / head-count / head-size grid.
    ``pipeline_mode`` selects the reduced grid; otherwise the full grid is used.
    """
    if pipeline_mode:
        batch_sizes = [2]
        seq_pairs = [(1, 128), (113, 211), (2048, 2048)]
        head_counts = [1, 3]
        head_dims = [16, 256]
    else:
        batch_sizes = [1, 5]
        seq_pairs = [
            (113, 203),
            (128, 217),
            (113, 211),
            (108, 256),
            (256, 512),
            (512, 256),
            (1024, 1024),
            (1023, 1024),
            (1024, 1023),
            (2048, 2048),
        ]
        head_counts = [1, 6, 16]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    # Same iteration order as nested loops: batch outermost, head size innermost.
    grid = (
        (batch, first_len, second_len, heads, head_dim)
        for batch in batch_sizes
        for first_len, second_len in seq_pairs
        for heads in head_counts
        for head_dim in head_dims
    )
    for batch, first_len, second_len, heads, head_dim in grid:
        # The same count is passed for both head arguments; the fourth
        # argument is always 0.
        case = Config(batch, first_len, second_len, 0, heads, heads, head_dim)
        yield str(case), case
class TestMHA(unittest.TestCase):
def test_packed_mha(self):
@parameterized.expand(packed_mha_test_cases())
def test_packed_mha(self, _, config):
if not torch.cuda.is_available() or platform.system() != "Linux":
return
major, _ = torch.cuda.get_device_capability()
if major < 8:
return
print("-------- TEST PACKED MHA ---------")
batches = [2] if pipeline_mode else [1, 5]
seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
for b in batches:
for s in seqs:
for n in num_h:
for h in h_sizes:
config = Config(b, s, s, 0, n, n, h)
all_close = parity_check_mha(config, True)
self.assertTrue(all_close)
all_close = parity_check_mha(config, True)
self.assertTrue(all_close)
def test_mha(self):
@parameterized.expand(mha_test_cases())
def test_mha(self, _, config):
if not torch.cuda.is_available() or platform.system() != "Linux":
return
major, _ = torch.cuda.get_device_capability()
if major < 8:
return
print("-------- TEST MHA ---------")
batches = [2] if pipeline_mode else [1, 5]
seqs = (
[(1, 128), (113, 211), (2048, 2048)]
if pipeline_mode
else [
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
]
)
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
for b in batches:
for s, s2 in seqs:
for n in num_h:
for h in h_sizes:
config = Config(b, s, s2, 0, n, n, h)
all_close = parity_check_mha(config, False)
self.assertTrue(all_close)
all_close = parity_check_mha(config, False)
self.assertTrue(all_close)
class TestGQA(unittest.TestCase):
def test_gqa_no_past_memory_efficient(self):
if not torch.cuda.is_available():
return
major, minor = torch.cuda.get_device_capability()
torch.manual_seed(69)
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
else [
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
if major < 5 or (major == 5 and minor < 3):
return
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
for b in batches:
for sq, skv in seqs:
for n, n2 in num_h:
for h in h_sizes:
def gqa_no_past_memory_efficient_test_cases():
    """Yield parameter tuples for the GQA prompt (no past) memory-efficient test.

    Each tuple is ``(name, config, rotary, rotary_interleaved, packed)``, where
    ``config`` is a ``PromptConfig`` and ``name`` is ``str(config)`` followed by
    the three flags. ``torch.manual_seed(69)`` is called once, when iteration
    of this generator starts.
    """
    batch_sizes = [3] if pipeline_mode else [1, 3, 5]
    # The sequence-length pairs are the same in pipeline mode and in the full run.
    seq_pairs = [
        (127, 127),
        (35, 35),
        (2000, 2000),
        (200, 200),
        (240, 240),
    ]
    if pipeline_mode:
        head_pairs = [(32, 8), (9, 3), (4, 4)]
        head_dims = [16, 128, 256]
    else:
        head_pairs = [(6, 6), (6, 3), (9, 9), (9, 3)]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
    rotary_modes = [(True, False), (True, True), (False, False)]

    torch.manual_seed(69)

    # Same iteration order as nested loops: batch outermost, packed innermost.
    grid = (
        (batch, q_len, kv_len, heads, kv_heads, head_dim, rotary, rotary_interleaved, packed)
        for batch in batch_sizes
        for q_len, kv_len in seq_pairs
        for heads, kv_heads in head_pairs
        for head_dim in head_dims
        for rotary, rotary_interleaved in rotary_modes
        for packed in (False, True)
    )
    for batch, q_len, kv_len, heads, kv_heads, head_dim, rotary, rotary_interleaved, packed in grid:
        # Fourth argument is the two sequence lengths plus 8.
        config = PromptConfig(batch, q_len, kv_len, q_len + kv_len + 8, heads, kv_heads, head_dim)
        flags = f"{rotary}_{rotary_interleaved}_{packed}"
        yield (str(config) + flags, config, rotary, rotary_interleaved, packed)
def gqa_no_past_flash_attention_test_cases():
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
else [
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
torch.manual_seed(69)
for b in batches:
for sq, skv in seqs:
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
all_close = parity_check_gqa_prompt(
yield (
str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}",
config,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
local,
rotary,
rotary_interleaved,
packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_prompt_no_buff(
config,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
def test_gqa_no_past_flash_attention(self):
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
torch.manual_seed(69)
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
else [
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
if major < 8 or platform.system() != "Linux":
return
print("------- FLASH ATTENTION (PROMPT CASE) --------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
for b in batches:
for sq, skv in seqs:
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
all_close = parity_check_gqa_prompt(
config,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_prompt_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
def test_gqa_past_memory_efficient(self):
if not torch.cuda.is_available():
return
major, minor = torch.cuda.get_device_capability()
if major < 5 or (major == 5 and minor < 3):
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
if pipeline_mode
else [
(1, 128),
(1, 339),
(1, 1024),
(1, 5000),
(1, 800),
(1, 256),
(1, 799),
(1, 2048),
# (1, 128 * 512),
# (16, 128 * 512),
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
for b in batches:
for s, s2 in seqs:
for n, n2 in num_h:
for h in h_sizes:
def gqa_past_memory_efficient_test_cases():
    """Yield parameter tuples for the GQA token-generation memory-efficient test.

    Each tuple is ``(name, config, rotary, rotary_interleaved, packed)``, where
    ``config`` is a ``Config`` and ``name`` is ``str(config)`` followed by the
    three flags. The fourth ``Config`` argument is drawn with
    ``random.randint`` after ``random.seed(69)``, so the generated cases are
    reproducible for a given grid.
    """
    if pipeline_mode:
        batch_sizes = [5]
        seq_pairs = [(1, 128), (1, 1024), (1, 2048)]
        head_pairs = [(32, 8), (9, 3), (4, 4)]
        head_dims = [16, 128, 256]
    else:
        batch_sizes = [1, 3, 5]
        seq_pairs = [
            (1, 128),
            (1, 339),
            (1, 1024),
            (1, 5000),
            (1, 800),
            (1, 256),
            (1, 799),
            (1, 2048),
            # (1, 128 * 512),
            # (16, 128 * 512),
            # (128, 128),
        ]
        head_pairs = [(6, 6), (6, 3), (9, 9), (9, 3)]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
    rotary_modes = [(True, False), (True, True), (False, False)]

    random.seed(69)

    # Same iteration order as nested loops (batch outermost, packed innermost);
    # the order matters because one random draw is made per generated case.
    grid = (
        (batch, first_len, second_len, heads, kv_heads, head_dim, rotary, rotary_interleaved, packed)
        for batch in batch_sizes
        for first_len, second_len in seq_pairs
        for heads, kv_heads in head_pairs
        for head_dim in head_dims
        for rotary, rotary_interleaved in rotary_modes
        for packed in (False, True)
    )
    for batch, first_len, second_len, heads, kv_heads, head_dim, rotary, rotary_interleaved, packed in grid:
        gap = second_len - first_len
        # Draw from [1, gap] when there is room; otherwise use 0 without drawing.
        if gap > 0:
            drawn_len = random.randint(1, gap)
        else:
            drawn_len = 0
        config = Config(batch, first_len, second_len, drawn_len, heads, kv_heads, head_dim)
        flags = f"{rotary}_{rotary_interleaved}_{packed}"
        yield (str(config) + flags, config, rotary, rotary_interleaved, packed)
def gqa_past_flash_attention_test_cases():
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
if pipeline_mode
else [
(1, 128),
(1, 339),
(1, 1024),
(1, 5000),
(1, 800),
(1, 256),
(1, 799),
(1, 2048),
# (1, 128 * 512),
# (16, 128 * 512),
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
for b in batches:
for s, s2 in seqs:
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
all_close = parity_check_gqa_past(
yield (
str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}",
config,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
local,
rotary,
rotary_interleaved,
packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_past_no_buff(
config,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
def test_gqa_past_flash_attention(self):
class TestGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed):
if not torch.cuda.is_available():
return
major, minor = torch.cuda.get_device_capability()
if major < 5 or (major == 5 and minor < 3):
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
all_close = parity_check_gqa_prompt(
config,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_prompt_no_buff(
config,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
if pipeline_mode
else [
(1, 128),
(1, 339),
(1, 1024),
(1, 5000),
(1, 800),
(1, 256),
(1, 799),
(1, 2048),
# (1, 128 * 512),
# (16, 128 * 512),
# (128, 128),
]
if major < 8 or platform.system() != "Linux":
return
print("------- FLASH ATTENTION (PROMPT CASE) --------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
all_close = parity_check_gqa_prompt(
config,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
self.assertTrue(all_close)
all_close = parity_check_gqa_prompt_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
@parameterized.expand(gqa_past_memory_efficient_test_cases())
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed):
if not torch.cuda.is_available():
return
major, minor = torch.cuda.get_device_capability()
if major < 5 or (major == 5 and minor < 3):
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
all_close = parity_check_gqa_past(
config,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_past_no_buff(
config,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
@parameterized.expand(gqa_past_flash_attention_test_cases())
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
if major < 8 or platform.system() != "Linux":
return
print("------- FLASH ATTENTION (TOKEN GEN) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
for b in batches:
for s, s2 in seqs:
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
all_close = parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
all_close = parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
)
self.assertTrue(all_close)
if __name__ == "__main__":

View file

@ -6,5 +6,6 @@ numpy==1.26.0 ; python_version >= '3.12'
torch
coloredlogs==15.0
transformers==4.38.0
parameterized>=0.8.1
psutil
einops