diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 18cbab600..655a388cb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2111,9 +2111,6 @@ class GenerationMixin: # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. # - different models have a different cache name expected by the model (default = "past_key_values") # - `max_length`, prepared above, is used to determine the maximum cache length - # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) - cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" - user_defined_cache = model_kwargs.get(cache_name) max_cache_length = generation_config.max_length if ( inputs_tensor.shape[1] != input_ids_length @@ -2395,32 +2392,12 @@ class GenerationMixin: # Convert to legacy cache format if requested if ( - generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 + generation_config.return_legacy_cache is True and not is_torchdynamo_compiling() and hasattr(result, "past_key_values") - and hasattr(result.past_key_values, "to_legacy_cache") - and result.past_key_values.to_legacy_cache is not None + and getattr(result.past_key_values, "to_legacy_cache") is not None ): - # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type) - should_convert_cache = generation_config.return_legacy_cache - is_user_defined_cache = user_defined_cache is not None - is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 - or ( - isinstance(result.past_key_values, EncoderDecoderCache) - and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 - and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 - ) - ) - if not is_user_defined_cache and is_default_cache_type: - logger.warning_once( - "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " - "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " - "keep returning the legacy format, please set `return_legacy_cache=True`." 
- ) - should_convert_cache = True - if should_convert_cache: - result.past_key_values = result.past_key_values.to_legacy_cache() + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7499a5599..d59a18c59 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -26,7 +26,7 @@ import numpy as np import pytest from parameterized import parameterized -from transformers import AutoConfig, is_torch_available, pipeline, set_seed +from transformers import AutoConfig, is_torch_available, pipeline from transformers.testing_utils import ( is_flaky, require_accelerate, @@ -69,7 +69,7 @@ if is_torch_available(): SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1851,75 +1851,6 @@ class GenerationTesterMixin: ) ) - @parameterized.expand([(1, False), (1, True), (4, False)]) - @pytest.mark.generate - def test_new_cache_format(self, num_beams, do_sample): - # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). - # 👉 tests with and without beam search so that we can test with and without cache reordering. - # 👉 tests with and without sampling so we can cover the most common use cases. - for model_class in self.all_generative_model_classes: - if not model_class._supports_cache_class: - self.skipTest(reason="This model does not support the new cache format") - - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "max_new_tokens": 5, - "do_sample": do_sample, - "num_beams": num_beams, - "num_return_sequences": num_beams, - "return_dict_in_generate": True, # Required to return `past_key_values` - "use_cache": True, - } - - # Sets seed before calling `generate` for the case with do_sample=True - seed = torch.randint(0, 1000000, (1,)).item() - set_seed(seed) - legacy_results = model.generate(**generation_kwargs, **inputs_dict) - set_seed(seed) - if config.is_encoder_decoder: - cache_cls = EncoderDecoderCache - past_key_values = cache_cls(DynamicCache(), DynamicCache()) - else: - cache_cls = DynamicCache - past_key_values = cache_cls() - - new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict) - - # The two sets of generated sequences must match, despite the cache format between forward passes being - # different - self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) - self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) - self.assertTrue(isinstance(new_results.past_key_values, cache_cls)) - - # The contents of the two caches, when converted to the same format (in both directions!), must match - legacy_cache = legacy_results.past_key_values - new_cache_converted = new_results.past_key_values.to_legacy_cache() - for layer_idx in range(len(legacy_cache)): - for kv_idx in range(len(legacy_cache[layer_idx])): - # TODO: @raushan, please look into this for new cache format - if legacy_cache[layer_idx][kv_idx] != []: - self.assertTrue( - torch.allclose( - 
legacy_cache[layer_idx][kv_idx], - new_cache_converted[layer_idx][kv_idx], - ) - ) - - new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) - for layer_idx in range(len(new_cache)): - for kv_idx in range(len(new_cache[layer_idx])): - # TODO: @raushan, please look into this for new cache format - if new_cache[layer_idx][kv_idx] != []: - self.assertTrue( - torch.allclose( - new_cache[layer_idx][kv_idx], - legacy_cache_converted[layer_idx][kv_idx], - ) - ) - @parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5) @require_torch_gpu @pytest.mark.generate @@ -2438,11 +2369,11 @@ class GenerationTesterMixin: ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): - self.assertIsInstance(past_key_values, tuple) - self.assertListEqual( - [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], - [True] * len(past_key_values), - ) + self.assertIsInstance(past_key_values, (tuple, Cache)) + + # Encoder-decoder models: pull and verify the decoder cache + if isinstance(past_key_values, EncoderDecoderCache): + past_key_values = past_key_values.self_attention_cache # (batch, head, seq_length, head_features) expected_shape = ( @@ -2451,15 +2382,32 @@ class GenerationTesterMixin: seq_length, config.hidden_size // config.num_attention_heads, ) - # check shape key, value - self.assertListEqual( - [layer_past_key_values[0].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), - ) - self.assertListEqual( - [layer_past_key_values[1].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), - ) + + if isinstance(past_key_values, Cache): + self.assertListEqual( + [key_tensor.shape for key_tensor in past_key_values.key_cache], + [expected_shape] * len(past_key_values.key_cache), + ) + self.assertListEqual( + [value_tensor.shape for value_tensor in past_key_values.value_cache], + [expected_shape] * len(past_key_values.value_cache), + ) + + # Legacy cache format checks. This branch should be removed when all models use `Cache` by default + else: + self.assertListEqual( + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], + [True] * len(past_key_values), + ) + # check shape key, value + self.assertListEqual( + [layer_past_key_values[0].shape for layer_past_key_values in past_key_values], + [expected_shape] * len(past_key_values), + ) + self.assertListEqual( + [layer_past_key_values[1].shape for layer_past_key_values in past_key_values], + [expected_shape] * len(past_key_values), + ) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. 
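For context on the helper change above: `_check_past_key_values_for_generate` now accepts either the legacy tuple of tuples or a `Cache` instance. Below is a minimal, self-contained sketch (not part of the patch) of the `Cache` layout the new branch asserts; all sizes are illustrative assumptions, not values taken from the test suite.

import torch
from transformers.cache_utils import DynamicCache

batch, num_heads, seq_len, head_dim = 2, 4, 6, 8  # illustrative sizes only
cache = DynamicCache()
for layer_idx in range(3):
    key = torch.zeros(batch, num_heads, seq_len, head_dim)
    value = torch.zeros(batch, num_heads, seq_len, head_dim)
    cache.update(key, value, layer_idx)  # appends one layer to key_cache / value_cache

# Roughly what the new `Cache` branch of the helper checks: per-layer tensors of shape
# (batch, num_heads, seq_len, head_dim) in `key_cache` and `value_cache`
expected_shape = (batch, num_heads, seq_len, head_dim)
assert all(k.shape == expected_shape for k in cache.key_cache)
assert all(v.shape == expected_shape for v in cache.value_cache)

# The legacy view that the remaining `else` branch still covers, and that `generate`
# still produces when `return_legacy_cache=True` is passed explicitly
legacy = cache.to_legacy_cache()  # tuple of per-layer (key, value) tuples
assert isinstance(legacy, tuple) and legacy[0][0].shape == expected_shape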
diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index b6f1da56c..ead68f1f2 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -268,18 +268,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi def test_sdpa_can_dispatch_on_flash(self): pass - @unittest.skip(reason="") - def test_new_cache_format_0(self): - pass - - @unittest.skip(reason="") - def test_new_cache_format_1(self): - pass - - @unittest.skip(reason="") - def test_new_cache_format_2(self): - pass - @unittest.skip(reason="Feedforward chunking is not yet supported") def test_feed_forward_chunking(self): pass diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 45819e66b..9356824da 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -18,7 +18,6 @@ import inspect import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, BambaConfig, is_torch_available from transformers.testing_utils import ( @@ -395,11 +394,6 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], ) - @unittest.skip(reason="Bamba has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - def test_batching_equivalence(self): # need to disable the tril input mask orig = self.model_tester.use_input_mask diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 144846772..55da9e5ee 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -103,11 +103,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): def test_dola_decoding_sample(self): pass - @parameterized.expand([(1, False), (1, True), (4, False)]) - @unittest.skip("Cohere2 has HybridCache and doesn't support old tuple format at all") - def test_new_cache_format(self, num_beams, do_sample): - pass - @unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv") def test_generate_continue_from_past_key_values(self): pass diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index d65c961bc..57c6331c8 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -117,11 +117,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_dola_decoding_sample(self): pass - @parameterized.expand([(1, False), (1, True), (4, False)]) - @unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all") - def test_new_cache_format(self, num_beams, do_sample): - pass - @unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv") def test_generate_continue_from_past_key_values(self): pass diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index ef0b58315..2f284763e 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -19,7 +19,6 @@ import tempfile import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, JambaConfig, is_torch_available from transformers.testing_utils import ( @@ -550,11 +549,6 @@ 
class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi """ self.skipTest(reason="Jamba flash attention does not support right padding") - @unittest.skip(reason="Jamba has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class JambaModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index dc510f0ff..ba7dc5377 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -18,7 +18,6 @@ import gc import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, JetMoeConfig, is_torch_available from transformers.testing_utils import ( @@ -299,10 +298,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix test_disk_offload_bin = False test_disk_offload_safetensors = False - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - def setUp(self): self.model_tester = JetMoeModelTester(self) self.config_tester = ConfigTester( diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index ee47f98a1..fc2d94c75 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -19,7 +19,6 @@ import tempfile import unittest import pytest -from parameterized import parameterized from transformers import AutoTokenizer, ZambaConfig, is_torch_available from transformers.testing_utils import ( @@ -551,11 +550,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi """ self.skipTest(reason="Zamba flash attention does not support right padding") - @unittest.skip(reason="Zamba has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch class ZambaModelIntegrationTest(unittest.TestCase):
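For reference, a hedged usage sketch of the user-facing effect of the `generate` change in this patch: the cache is now returned as a `Cache` object unless `return_legacy_cache=True` is requested explicitly. The checkpoint name below is a placeholder assumption; any causal-LM checkpoint whose model supports the `Cache` API behaves the same way.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # placeholder checkpoint, assumed only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt")

# Default after this change: `past_key_values` stays a `Cache` instance (e.g. `DynamicCache`)
out = model.generate(**inputs, max_new_tokens=4, return_dict_in_generate=True, use_cache=True)
print(type(out.past_key_values))

# Opting back into the legacy tuple-of-tuples format now requires the explicit flag
out_legacy = model.generate(
    **inputs,
    max_new_tokens=4,
    return_dict_in_generate=True,
    use_cache=True,
    return_legacy_cache=True,
)
print(type(out_legacy.past_key_values))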