From 8d99b1a8dc5318bde4463817c02552ebca0cf547 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 Nov 2024 12:26:46 -0800 Subject: [PATCH] reduce GQA test combinations (#22918) ### Description * Reduce GQA test combinations to save about 35 minutes test time in CI pipelines. * Show latency of transformers tests * Use seed in DMMHA test to avoid random failure. * For test_flash_attn_rocm.py, test skipping condition from "has cuda ep" to "not has rocm ep", so that it does not run in cpu build. * For test_flash_attn_cuda.py, move flash attention and memory efficient attention tests to different classes, so that we can skip a test suite instead of checking in each test. ### Motivation and Context It takes too long to run GQA tests in CI pipelines since there are too many combinations. ###### Linux GPU CI Pipeline Before: 5097 passed, 68 skipped, 8 warnings in 1954.64s (0:32:34) After: 150 passed, 176 skipped, 8 warnings in 530.38s (0:08:50) Time Saved: **1424** seconds (0:23:44) ###### Windows GPU CUDA CI Pipeline Before: 1781 passed, 72 skipped, 6 warnings in 605.48s (0:10:05) After: 116 passed, 118 skipped, 6 warnings in 275.48s (0:04:35) Time Saved: **330** seconds (0:05:30) ###### Linux CPU CI Pipeline Before: 5093 passed, 72 skipped, 4 warnings in 467.04s (0:07:47) - 212.96s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past - 154.12s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past - 26.45s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch After: 116 passed, 210 skipped, 4 warnings in 93.41s (0:01:33) - 0.97s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past - 19.23s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past - 2.41s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch Time Saved: **374** seconds (0:06:14). --- ...oder_masked_multihead_attention_op_test.cc | 2 +- .../transformers/test_flash_attn_cuda.py | 180 ++++++++---------- .../transformers/test_flash_attn_rocm.py | 20 +- .../test/python/transformers/test_gqa_cpu.py | 20 +- tools/ci_build/build.py | 2 +- 5 files changed, 103 insertions(+), 121 deletions(-) diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 17685ab82f..208545eacf 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain); FixedPatternValueGenerator generator{}; - RandomValueGenerator random{}; + RandomValueGenerator random{123}; // Attributes tester.AddAttribute("num_heads", static_cast(num_heads)); diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 46ab905977..a74d5389e9 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -24,7 +24,7 @@ from packaging import version from parameterized import parameterized from test_gqa_cpu import smooth_softmax_ref -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers torch.manual_seed(0) @@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff( def has_flash_attention(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, _ = torch.cuda.get_device_capability() return major >= 8 and ( platform.system() == "Linux" @@ -2009,6 +2011,8 @@ def has_flash_attention(): def has_memory_efficient(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, minor = torch.cuda.get_device_capability() if major < 5 or (major == 5 and minor < 3): return False @@ -2047,8 +2051,8 @@ def mha_test_cases(): (2048, 2048), ] ) - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [3] if pipeline_mode else [1, 6, 16] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for s, s2 in seqs: @@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), (2000, 2000), - (200, 200), - (240, 240), ] if pipeline_mode else [ @@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), (240, 240), ] if pipeline_mode @@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases(): def gqa_past_memory_efficient_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 1024)] if pipeline_mode else [ (1, 128), @@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases(): def gqa_past_flash_attention_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 2048)] if pipeline_mode else [ (1, 128), @@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases(): def gqa_interactive_one_batch_flash_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(128, 2048)] if pipeline_mode else [ (1, 128), @@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(32, 128)] if pipeline_mode else [ (1, 128), @@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2326,41 +2322,10 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): ) -class TestGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") - - parity_check_gqa_prompt( - config, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - parity_check_gqa_prompt_no_buff( - config, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - +@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") +class TestFlashGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return print("------- FLASH ATTENTION (PROMPT CASE) --------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -2385,40 +2350,8 @@ class TestGQA(unittest.TestCase): use_smooth_softmax=False, ) - @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") - - parity_check_gqa_past( - config, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - parity_check_gqa_past_no_buff( - config, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - @parameterized.expand(gqa_past_flash_attention_test_cases()) def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -2449,8 +2382,6 @@ class TestGQA(unittest.TestCase): @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - if not has_flash_attention(): - return print("------- FLASH ATTENTION (INTERACTIVE) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -2475,10 +2406,67 @@ class TestGQA(unittest.TestCase): packed=packed, ) + +@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.") +class TestMemoryEfficientGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) + def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") + + parity_check_gqa_prompt( + config, + rtol=5e-3, + atol=5e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=False, + ) + parity_check_gqa_prompt_no_buff( + config, + rtol=5e-3, + atol=5e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=True, + ) + + @parameterized.expand(gqa_past_memory_efficient_test_cases()) + def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") + + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=True, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=False, + ) + @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): - if not has_memory_efficient(): - return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index 99460722c2..a5910c28c2 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -16,16 +16,16 @@ from test_flash_attn_cuda import ( import onnxruntime -class TestGQA(unittest.TestCase): +@unittest.skipIf( + (not torch.cuda.is_available()) + or (platform.system() != "Linux") + or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()), + reason="ROCm is not available, skipping tests.", +) +class TestRocmGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (PROMPT CASE) --------") parity_check_gqa_prompt( @@ -52,12 +52,6 @@ class TestGQA(unittest.TestCase): @parameterized.expand(gqa_past_flash_attention_test_cases()) def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (TOKEN GEN) -------") parity_check_gqa_past( diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 08ec5de328..77b4b326bf 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -1900,7 +1900,7 @@ class TestGQA(unittest.TestCase): def test_gqa_no_past(self): torch.manual_seed(69) print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ (127, 127), @@ -1916,8 +1916,8 @@ class TestGQA(unittest.TestCase): (8000, 8000), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for sq, skv in seqs: for n, n2 in num_h: @@ -1954,9 +1954,9 @@ class TestGQA(unittest.TestCase): def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [1] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 128)] if pipeline_mode else [ (1, 128), @@ -1972,8 +1972,8 @@ class TestGQA(unittest.TestCase): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: @@ -2018,7 +2018,7 @@ class TestGQA(unittest.TestCase): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(256, 2048)] if pipeline_mode else [ (1, 128), @@ -2034,8 +2034,8 @@ class TestGQA(unittest.TestCase): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index aa1198102f..3bfbc01086 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2149,7 +2149,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): ], cwd=SCRIPT_DIR, ) - run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "--durations=0", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version]) run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version])