mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
reduce GQA test combinations (#22918)
### Description

* Reduce GQA test combinations to save about 35 minutes of test time in CI pipelines.
* Show the latency of each transformers test (pytest is now run with `--durations=0`).
* Use a fixed seed in the DecoderMaskedMultiHeadAttention (DMMHA) test to avoid random failures.
* For test_flash_attn_rocm.py, change the skip condition from "CUDA EP is available" to "ROCm EP is not available", so that the tests do not run in a CPU build.
* For test_flash_attn_cuda.py, move the flash attention and memory efficient attention tests into separate classes, so that a whole test suite can be skipped instead of checking availability in each test.

### Motivation and Context

It takes too long to run GQA tests in CI pipelines because there are too many combinations.

###### Linux GPU CI Pipeline

Before: 5097 passed, 68 skipped, 8 warnings in 1954.64s (0:32:34)
After: 150 passed, 176 skipped, 8 warnings in 530.38s (0:08:50)
Time Saved: **1424** seconds (0:23:44)

###### Windows GPU CUDA CI Pipeline

Before: 1781 passed, 72 skipped, 6 warnings in 605.48s (0:10:05)
After: 116 passed, 118 skipped, 6 warnings in 275.48s (0:04:35)
Time Saved: **330** seconds (0:05:30)

###### Linux CPU CI Pipeline

Before: 5093 passed, 72 skipped, 4 warnings in 467.04s (0:07:47)
- 212.96s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 154.12s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 26.45s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

After: 116 passed, 210 skipped, 4 warnings in 93.41s (0:01:33)
- 0.97s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 19.23s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 2.41s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

Time Saved: **374** seconds (0:06:14)
This commit is contained in:
parent
55f0559e5d
commit
8d99b1a8dc
5 changed files with 103 additions and 121 deletions
|
|
@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool
|
|||
|
||||
OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
|
||||
FixedPatternValueGenerator generator{};
|
||||
RandomValueGenerator random{};
|
||||
RandomValueGenerator random{123};
|
||||
|
||||
// Attributes
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from packaging import version
|
|||
from parameterized import parameterized
|
||||
from test_gqa_cpu import smooth_softmax_ref
|
||||
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
|
@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
|
|||
def has_flash_attention():
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
if "CUDAExecutionProvider" not in get_available_providers():
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return major >= 8 and (
|
||||
platform.system() == "Linux"
|
||||
|
|
@ -2009,6 +2011,8 @@ def has_flash_attention():
|
|||
def has_memory_efficient():
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
if "CUDAExecutionProvider" not in get_available_providers():
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 5 or (major == 5 and minor < 3):
|
||||
return False
|
||||
|
|
@ -2047,8 +2051,8 @@ def mha_test_cases():
|
|||
(2048, 2048),
|
||||
]
|
||||
)
|
||||
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [3] if pipeline_mode else [1, 6, 16]
|
||||
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
|
|
@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
|
|||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
if pipeline_mode
|
||||
else [
|
||||
|
|
@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
|
|||
(240, 240),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
torch.manual_seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
|
|||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
(35, 35),
|
||||
(2000, 2000),
|
||||
(200, 200),
|
||||
(240, 240),
|
||||
]
|
||||
if pipeline_mode
|
||||
|
|
@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
|
|||
(240, 240),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
torch.manual_seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
|
|||
def gqa_past_memory_efficient_test_cases():
|
||||
batches = [5] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
[(1, 1024)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
|
|||
def gqa_past_flash_attention_test_cases():
|
||||
batches = [5] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
[(1, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
|
|||
def gqa_interactive_one_batch_flash_attention_test_cases():
|
||||
batches = [1]
|
||||
seqs = (
|
||||
[(2, 128), (128, 129), (32, 128), (256, 2048)]
|
||||
[(128, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
|
|||
def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
|
||||
batches = [1]
|
||||
seqs = (
|
||||
[(2, 128), (128, 129), (32, 128), (256, 2048)]
|
||||
[(32, 128)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
|
||||
for b in batches:
|
||||
|
|
@ -2326,41 +2322,10 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
|
|||
)
|
||||
|
||||
|
||||
class TestGQA(unittest.TestCase):
|
||||
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
|
||||
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
|
||||
if not has_memory_efficient():
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
|
||||
|
||||
parity_check_gqa_prompt(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=False,
|
||||
)
|
||||
parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
|
||||
class TestFlashGQA(unittest.TestCase):
|
||||
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
|
||||
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
|
||||
if not has_flash_attention():
|
||||
return
|
||||
print("------- FLASH ATTENTION (PROMPT CASE) --------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
|
||||
|
|
@ -2385,40 +2350,8 @@ class TestGQA(unittest.TestCase):
|
|||
use_smooth_softmax=False,
|
||||
)
|
||||
|
||||
@parameterized.expand(gqa_past_memory_efficient_test_cases())
|
||||
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
|
||||
if not has_memory_efficient():
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
|
||||
|
||||
parity_check_gqa_past(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=False,
|
||||
)
|
||||
|
||||
@parameterized.expand(gqa_past_flash_attention_test_cases())
|
||||
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
|
||||
if not has_flash_attention():
|
||||
return
|
||||
print("------- FLASH ATTENTION (TOKEN GEN) -------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
|
||||
|
|
@ -2449,8 +2382,6 @@ class TestGQA(unittest.TestCase):
|
|||
|
||||
@parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
|
||||
def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
|
||||
if not has_flash_attention():
|
||||
return
|
||||
print("------- FLASH ATTENTION (INTERACTIVE) -------")
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
|
||||
|
||||
|
|
@ -2475,10 +2406,67 @@ class TestGQA(unittest.TestCase):
|
|||
packed=packed,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
|
||||
class TestMemoryEfficientGQA(unittest.TestCase):
|
||||
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
|
||||
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
|
||||
|
||||
parity_check_gqa_prompt(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=False,
|
||||
)
|
||||
parity_check_gqa_prompt_no_buff(
|
||||
config,
|
||||
rtol=5e-3,
|
||||
atol=5e-3,
|
||||
past_format=Formats.BNSH,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
|
||||
@parameterized.expand(gqa_past_memory_efficient_test_cases())
|
||||
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
|
||||
|
||||
parity_check_gqa_past(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
past_format=Formats.BNSH,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
softcap=softcap,
|
||||
use_smooth_softmax=False,
|
||||
)
|
||||
|
||||
@parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
|
||||
def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
|
||||
if not has_memory_efficient():
|
||||
return
|
||||
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
|
||||
print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,16 +16,16 @@ from test_flash_attn_cuda import (
|
|||
import onnxruntime
|
||||
|
||||
|
||||
class TestGQA(unittest.TestCase):
|
||||
@unittest.skipIf(
|
||||
(not torch.cuda.is_available())
|
||||
or (platform.system() != "Linux")
|
||||
or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()),
|
||||
reason="ROCm is not available, skipping tests.",
|
||||
)
|
||||
class TestRocmGQA(unittest.TestCase):
|
||||
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
|
||||
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
|
||||
config.ep = "ROCMExecutionProvider"
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
if platform.system() != "Linux":
|
||||
return
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
return
|
||||
print("------- FLASH ATTENTION (PROMPT CASE) --------")
|
||||
|
||||
parity_check_gqa_prompt(
|
||||
|
|
@ -52,12 +52,6 @@ class TestGQA(unittest.TestCase):
|
|||
@parameterized.expand(gqa_past_flash_attention_test_cases())
|
||||
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
|
||||
config.ep = "ROCMExecutionProvider"
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
if platform.system() != "Linux":
|
||||
return
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
return
|
||||
print("------- FLASH ATTENTION (TOKEN GEN) -------")
|
||||
|
||||
parity_check_gqa_past(
|
||||
|
|
|
|||
|
|
@ -1900,7 +1900,7 @@ class TestGQA(unittest.TestCase):
|
|||
def test_gqa_no_past(self):
|
||||
torch.manual_seed(69)
|
||||
print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
|
||||
batches = [1, 3] if pipeline_mode else [1, 3, 5]
|
||||
batches = [3] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[
|
||||
(127, 127),
|
||||
|
|
@ -1916,8 +1916,8 @@ class TestGQA(unittest.TestCase):
|
|||
(8000, 8000),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
for b in batches:
|
||||
for sq, skv in seqs:
|
||||
for n, n2 in num_h:
|
||||
|
|
@ -1954,9 +1954,9 @@ class TestGQA(unittest.TestCase):
|
|||
|
||||
def test_gqa_past(self):
|
||||
print("-------- TEST GQA PAST (TOKEN GEN) ---------")
|
||||
batches = [1, 3] if pipeline_mode else [1, 3, 5]
|
||||
batches = [1] if pipeline_mode else [1, 3, 5]
|
||||
seqs = (
|
||||
[(1, 128), (1, 1024), (1, 2048)]
|
||||
[(1, 128)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -1972,8 +1972,8 @@ class TestGQA(unittest.TestCase):
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
|
|
@ -2018,7 +2018,7 @@ class TestGQA(unittest.TestCase):
|
|||
print("-------- TEST GQA INTERACTIVE ---------")
|
||||
batches = [1]
|
||||
seqs = (
|
||||
[(2, 128), (128, 129), (32, 128), (256, 2048)]
|
||||
[(256, 2048)]
|
||||
if pipeline_mode
|
||||
else [
|
||||
(1, 128),
|
||||
|
|
@ -2034,8 +2034,8 @@ class TestGQA(unittest.TestCase):
|
|||
# (128, 128),
|
||||
]
|
||||
)
|
||||
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
|
||||
h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||
random.seed(69)
|
||||
for b in batches:
|
||||
for s, s2 in seqs:
|
||||
|
|
|
|||
|
|
@ -2149,7 +2149,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
|
|||
],
|
||||
cwd=SCRIPT_DIR,
|
||||
)
|
||||
run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
|
||||
run_subprocess([sys.executable, "-m", "pytest", "--durations=0", "transformers"], cwd=cwd)
|
||||
# Restore initial numpy/protobuf version in case other tests use it
|
||||
run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version])
|
||||
run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version])
|
||||
|
|
|
|||
Loading…
Reference in a new issue