mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
test: refactor flash_attn tests to use parameterized (#20913)
Use `parameterized` to decompose the huge test cases into individually reported tests. This will make it possible to add ROCm support. --------- Co-authored-by: Guangyun Han <guangyunhan@microsoft.com@h100vm-ort.kxelwkzfzxguje5bxvwxxs135a.gvxx.internal.cloudapp.net>
This commit is contained in:
parent
b3fc9b5a0e
commit
67c8befd1d
2 changed files with 329 additions and 249 deletions
|
|
@ -20,6 +20,7 @@ import torch
|
|||
from bert_padding import pad_input, unpad_input
|
||||
from einops import rearrange, repeat
|
||||
from onnx import TensorProto, helper
|
||||
from parameterized import parameterized
|
||||
from rotary_flash import apply_rotary_emb
|
||||
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
|
|
@ -56,6 +57,13 @@ class Config:
|
|||
self.kv_num_heads = n2
|
||||
self.head_size = h
|
||||
|
||||
def __repr__(self):
    """Return a ``Config(name=value, ...)`` string listing every dimension.

    The string is used as the human-readable name of generated test cases,
    so the field order and separators are kept stable.
    """
    field_names = (
        "batch_size",
        "sequence_length",
        "kv_sequence_length",
        "past_sequence_length",
        "num_heads",
        "kv_num_heads",
        "head_size",
    )
    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
    return f"Config({rendered})"
|
||||
|
||||
|
||||
class PromptConfig:
|
||||
batch_size = 0
|
||||
|
|
@ -75,6 +83,13 @@ class PromptConfig:
|
|||
self.kv_num_heads = n2
|
||||
self.head_size = h
|
||||
|
||||
def __repr__(self):
    """Return a ``PromptConfig(name=value, ...)`` string listing every dimension.

    The string is used as the human-readable name of generated test cases,
    so the field order and separators are kept stable.
    """
    field_names = (
        "batch_size",
        "q_sequence_length",
        "kv_sequence_length",
        "buffer_sequence_length",
        "num_heads",
        "kv_num_heads",
        "head_size",
    )
    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
    return f"PromptConfig({rendered})"
|
||||
|
||||
|
||||
def create_packed_multihead_attention_graph(config):
|
||||
nodes = [
|
||||
|
|
@ -1974,293 +1989,357 @@ def parity_check_gqa_past_no_buff(
|
|||
return all_close
|
||||
|
||||
|
||||
def packed_mha_test_cases():
    """Yield ``(name, config)`` pairs for ``TestMHA.test_packed_mha``.

    The module-level ``pipeline_mode`` flag selects a reduced grid; otherwise
    the full grid is produced. Each case uses the same value for both
    sequence-length arguments of ``Config`` and the same head count for both
    head arguments. The case name is ``str(config)``.
    """
    if pipeline_mode:
        batch_sizes = [2]
        seq_lens = [8, 97, 256, 1024]
        head_counts = [1, 3]
        head_dims = [16, 256]
    else:
        batch_sizes = [1, 5]
        seq_lens = [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
        head_counts = [1, 6, 16]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    # Iteration order (batch -> sequence -> heads -> head size) matches the
    # original nested loops so the generated test numbering is unchanged.
    all_configs = (
        Config(batch, seq, seq, 0, heads, heads, dim)
        for batch in batch_sizes
        for seq in seq_lens
        for heads in head_counts
        for dim in head_dims
    )
    for cfg in all_configs:
        yield str(cfg), cfg
|
||||
|
||||
|
||||
def mha_test_cases():
    """Yield ``(name, config)`` pairs for ``TestMHA.test_mha``.

    The module-level ``pipeline_mode`` flag selects a reduced grid; otherwise
    the full grid is produced. Each sequence entry is a pair whose two values
    are passed as the second and third ``Config`` arguments. The case name is
    ``str(config)``.
    """
    if pipeline_mode:
        batch_sizes = [2]
        seq_pairs = [(1, 128), (113, 211), (2048, 2048)]
        head_counts = [1, 3]
        head_dims = [16, 256]
    else:
        batch_sizes = [1, 5]
        seq_pairs = [
            (113, 203),
            (128, 217),
            (113, 211),
            (108, 256),
            (256, 512),
            (512, 256),
            (1024, 1024),
            (1023, 1024),
            (1024, 1023),
            (2048, 2048),
        ]
        head_counts = [1, 6, 16]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    # Iteration order (batch -> sequence pair -> heads -> head size) matches
    # the original nested loops so the generated test numbering is unchanged.
    all_configs = (
        Config(batch, seq_a, seq_b, 0, heads, heads, dim)
        for batch in batch_sizes
        for seq_a, seq_b in seq_pairs
        for heads in head_counts
        for dim in head_dims
    )
    for cfg in all_configs:
        yield str(cfg), cfg
|
||||
|
||||
|
||||
class TestMHA(unittest.TestCase):
|
||||
def test_packed_mha(self):
|
||||
@parameterized.expand(packed_mha_test_cases())
|
||||
def test_packed_mha(self, _, config):
|
||||
if not torch.cuda.is_available() or platform.system() != "Linux":
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8:
|
||||
return
|
||||
print("-------- TEST PACKED MHA ---------")
|
||||
batches = [2] if pipeline_mode else [1, 5]
|
||||
seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
|
||||
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
for b in batches:
|
||||
for s in seqs:
|
||||
for n in num_h:
|
||||
for h in h_sizes:
|
||||
config = Config(b, s, s, 0, n, n, h)
|
||||
all_close = parity_check_mha(config, True)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_mha(config, True)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
def test_mha(self):
|
||||
@parameterized.expand(mha_test_cases())
|
||||
def test_mha(self, _, config):
|
||||
if not torch.cuda.is_available() or platform.system() != "Linux":
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8:
|
||||
return
|
||||
print("-------- TEST MHA ---------")
|
||||
batches = [2] if pipeline_mode else [1, 5]
|
||||
seqs = (
|
||||
[(1, 128), (113, 211), (2048, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
]
|
||||
)
|
||||
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n in num_h:
|
||||
for h in h_sizes:
|
||||
config = Config(b, s, s2, 0, n, n, h)
|
||||
all_close = parity_check_mha(config, False)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_mha(config, False)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
|
||||
class TestGQA(unittest.TestCase):
|
||||
def test_gqa_no_past_memory_efficient(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
torch.manual_seed(69)
|
||||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
for b in batches:
|
||||
for sq, skv in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
def gqa_no_past_memory_efficient_test_cases():
    """Yield parameter tuples for ``TestGQA.test_gqa_no_past_memory_efficient``.

    Each yielded tuple is ``(name, config, rotary, rotary_interleaved, packed)``
    where ``config`` is a ``PromptConfig`` and ``name`` is ``str(config)``
    followed by the three flag values joined with underscores.

    The module-level ``pipeline_mode`` flag selects a reduced grid for batch
    sizes, head-count pairs and head sizes. The sequence-length pairs are the
    same in both modes.
    """
    batches = [3] if pipeline_mode else [1, 3, 5]
    # The original conditional listed the same five pairs in both branches,
    # so a single list expresses it without implying a difference.
    seqs = [
        (127, 127),
        (35, 35),
        (2000, 2000),
        (200, 200),
        (240, 240),
    ]
    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
    # NOTE(review): this seeds torch once, when the generator is first advanced,
    # not before each generated test -- confirm that is the intended behavior.
    torch.manual_seed(69)

    for b in batches:
        for sq, skv in seqs:
            for n, n2 in num_h:
                for h in h_sizes:
                    for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
                        for packed in [False, True]:
                            # Fourth argument leaves 8 extra slots beyond sq + skv
                            # (presumably the buffer length; verify against PromptConfig.__init__).
                            config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
                            yield (
                                str(config) + f"{rotary}_{rotary_interleaved}_{packed}",
                                config,
                                rotary,
                                rotary_interleaved,
                                packed,
                            )
|
||||
|
||||
|
||||
def gqa_no_past_flash_attention_test_cases():
|
||||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
torch.manual_seed(69)
|
||||
|
||||
for b in batches:
|
||||
for sq, skv in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for packed in [False, True]:
|
||||
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
|
||||
all_close = parity_check_gqa_prompt(
|
||||
yield (
|
||||
str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}",
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
local,
|
||||
rotary,
|
||||
rotary_interleaved,
|
||||
packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
def test_gqa_no_past_flash_attention(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
torch.manual_seed(69)
|
||||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
if major < 8 or platform.system() != "Linux":
|
||||
return
|
||||
print("------- FLASH ATTENTION (PROMPT CASE) --------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
for b in batches:
|
||||
for sq, skv in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for packed in [False, True]:
|
||||
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
|
||||
all_close = parity_check_gqa_prompt(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
def test_gqa_past_memory_efficient(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
batches = [5] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
(1, 339),
|
||||
(1, 1024),
|
||||
(1, 5000),
|
||||
(1, 800),
|
||||
(1, 256),
|
||||
(1, 799),
|
||||
(1, 2048),
|
||||
# (1, 128 * 512),
|
||||
# (16, 128 * 512),
|
||||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
def gqa_past_memory_efficient_test_cases():
    """Yield parameter tuples for ``TestGQA.test_gqa_past_memory_efficient``.

    Each yielded tuple is ``(name, config, rotary, rotary_interleaved, packed)``
    where ``config`` is a ``Config`` and ``name`` is ``str(config)`` followed by
    the three flag values joined with underscores.

    The module-level ``pipeline_mode`` flag selects a reduced grid. The fourth
    ``Config`` argument is drawn from Python's global ``random`` generator,
    which is seeded with 69 when the generator is first advanced.
    """
    if pipeline_mode:
        batch_sizes = [5]
        seq_pairs = [(1, 128), (1, 1024), (1, 2048)]
        head_pairs = [(32, 8), (9, 3), (4, 4)]
        head_dims = [16, 128, 256]
    else:
        batch_sizes = [1, 3, 5]
        seq_pairs = [
            (1, 128),
            (1, 339),
            (1, 1024),
            (1, 5000),
            (1, 800),
            (1, 256),
            (1, 799),
            (1, 2048),
            # (1, 128 * 512),
            # (16, 128 * 512),
            # (128, 128),
        ]
        head_pairs = [(6, 6), (6, 3), (9, 9), (9, 3)]
        head_dims = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
    rotary_modes = ((True, False), (True, True), (False, False))
    random.seed(69)

    for batch in batch_sizes:
        for seq_len, kv_len in seq_pairs:
            gap = kv_len - seq_len
            for heads, kv_heads in head_pairs:
                for dim in head_dims:
                    for rotary, rotary_interleaved in rotary_modes:
                        for packed in (False, True):
                            # One draw per generated case, in loop order, so the
                            # random sequence matches the original nesting.
                            if gap > 0:
                                past_len = random.randint(1, gap)
                            else:
                                past_len = 0
                            cfg = Config(batch, seq_len, kv_len, past_len, heads, kv_heads, dim)
                            label = str(cfg) + f"{rotary}_{rotary_interleaved}_{packed}"
                            yield label, cfg, rotary, rotary_interleaved, packed
|
||||
|
||||
|
||||
def gqa_past_flash_attention_test_cases():
|
||||
batches = [5] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
(1, 339),
|
||||
(1, 1024),
|
||||
(1, 5000),
|
||||
(1, 800),
|
||||
(1, 256),
|
||||
(1, 799),
|
||||
(1, 2048),
|
||||
# (1, 128 * 512),
|
||||
# (16, 128 * 512),
|
||||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for packed in [False, True]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
all_close = parity_check_gqa_past(
|
||||
yield (
|
||||
str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}",
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
local,
|
||||
rotary,
|
||||
rotary_interleaved,
|
||||
packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
def test_gqa_past_flash_attention(self):
|
||||
|
||||
class TestGQA(unittest.TestCase):
|
||||
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
|
||||
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
|
||||
|
||||
all_close = parity_check_gqa_prompt(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
|
||||
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
batches = [5] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
(1, 339),
|
||||
(1, 1024),
|
||||
(1, 5000),
|
||||
(1, 800),
|
||||
(1, 256),
|
||||
(1, 799),
|
||||
(1, 2048),
|
||||
# (1, 128 * 512),
|
||||
# (16, 128 * 512),
|
||||
# (128, 128),
|
||||
]
|
||||
if major < 8 or platform.system() != "Linux":
|
||||
return
|
||||
print("------- FLASH ATTENTION (PROMPT CASE) --------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
|
||||
all_close = parity_check_gqa_prompt(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
@parameterized.expand(gqa_past_memory_efficient_test_cases())
|
||||
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
|
||||
|
||||
all_close = parity_check_gqa_past(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
@parameterized.expand(gqa_past_flash_attention_test_cases())
|
||||
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8 or platform.system() != "Linux":
|
||||
return
|
||||
print("------- FLASH ATTENTION (TOKEN GEN) -------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
for n, n2 in num_h:
|
||||
for h in h_sizes:
|
||||
for local in [False, True]:
|
||||
for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
|
||||
for packed in [False, True]:
|
||||
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
|
||||
config = Config(b, s, s2, sp, n, n2, h)
|
||||
all_close = parity_check_gqa_past(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
all_close = parity_check_gqa_past(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
all_close = parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
local=local,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
)
|
||||
self.assertTrue(all_close)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,5 +6,6 @@ numpy==1.26.0 ; python_version >= '3.12'
|
|||
torch
|
||||
coloredlogs==15.0
|
||||
transformers==4.38.0
|
||||
parameterized>=0.8.1
|
||||
psutil
|
||||
einops
|
||||
|
|
|
|||
Loading…
Reference in a new issue