From 27c7f971c0dcd3bb423ea221fe2bce751d313119 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Fri, 26 Jul 2024 17:41:27 +0800
Subject: [PATCH] [tests] fix `static` cache implementation is not compatible with `attn_implementation==flash_attention_2` (#32039)

* add flash attention check

* fix

* fix
---
 tests/utils/test_cache_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index b8366cc27..98db42cfe 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -290,7 +290,7 @@ class CacheIntegrationTest(unittest.TestCase):
         self.assertTrue(decoded[0].endswith(last_output))
 
     @require_torch_gpu
-    @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
+    @parameterized.expand(["eager", "sdpa"])
     def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color is the one that complements the skin tone of the",
@@ -330,7 +330,7 @@ class CacheIntegrationTest(unittest.TestCase):
         self.assertListEqual(decoded, EXPECTED_GENERATION)
 
     @require_torch_gpu
-    @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
+    @parameterized.expand(["eager", "sdpa"])
     def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color isЋ the one that complements the skin tone of",