[generate] return Cache object even if passed in a legacy format (#35673)

* generate returns a Cache object by default

* fix tests

* fix test for encoder-decoder models
Joao Gante 2025-01-16 17:06:24 +00:00 committed by GitHub
parent 2818307e93
commit 94af1c0aa2
9 changed files with 36 additions and 156 deletions
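Note on the behavior change: `generate` now keeps the `Cache` type of whatever cache the caller passes in, and the legacy tuple-of-tuples format becomes strictly opt-in via `return_legacy_cache=True`. A minimal sketch of the new contract — the checkpoint and generation settings below are illustrative, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.cache_utils import Cache, DynamicCache

    # Illustrative checkpoint; any generative model that supports `Cache` objects works.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hello", return_tensors="pt")

    # Passing a cache in the new format: after this commit, the returned cache
    # stays a `Cache` instance instead of being silently converted to tuples.
    out = model.generate(
        **inputs,
        past_key_values=DynamicCache(),
        max_new_tokens=5,
        return_dict_in_generate=True,
        use_cache=True,
    )
    assert isinstance(out.past_key_values, Cache)

    # The legacy tuple-of-tuples format is now strictly opt-in.
    out_legacy = model.generate(
        **inputs,
        max_new_tokens=5,
        return_dict_in_generate=True,
        use_cache=True,
        return_legacy_cache=True,
    )
    assert isinstance(out_legacy.past_key_values, tuple)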

View file

@@ -2111,9 +2111,6 @@ class GenerationMixin:
         # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
         # - different models have a different cache name expected by the model (default = "past_key_values")
         # - `max_length`, prepared above, is used to determine the maximum cache length
-        # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
-        cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
-        user_defined_cache = model_kwargs.get(cache_name)
         max_cache_length = generation_config.max_length
         if (
             inputs_tensor.shape[1] != input_ids_length
@@ -2395,32 +2392,12 @@
         # Convert to legacy cache format if requested
         if (
-            generation_config.return_legacy_cache is not False  # Should check for `True` after v4.47
+            generation_config.return_legacy_cache is True
             and not is_torchdynamo_compiling()
             and hasattr(result, "past_key_values")
             and hasattr(result.past_key_values, "to_legacy_cache")
-            and result.past_key_values.to_legacy_cache is not None
+            and getattr(result.past_key_values, "to_legacy_cache") is not None
         ):
-            # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type)
-            should_convert_cache = generation_config.return_legacy_cache
-            is_user_defined_cache = user_defined_cache is not None
-            is_default_cache_type = (
-                type(result.past_key_values) == DynamicCache  # noqa E721
-                or (
-                    isinstance(result.past_key_values, EncoderDecoderCache)
-                    and type(result.past_key_values.self_attention_cache) == DynamicCache  # noqa E721
-                    and type(result.past_key_values.cross_attention_cache) == DynamicCache  # noqa E721
-                )
-            )
-            if not is_user_defined_cache and is_default_cache_type:
-                logger.warning_once(
-                    "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
-                    "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
-                    "keep returning the legacy format, please set `return_legacy_cache=True`."
-                )
-                should_convert_cache = True
-            if should_convert_cache:
-                result.past_key_values = result.past_key_values.to_legacy_cache()
+            result.past_key_values = result.past_key_values.to_legacy_cache()
         return result

     def _has_unfinished_sequences(
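
With the deprecation-warning branch gone, conversion runs solely when explicitly requested. The conversion helpers themselves are unchanged and still live on the cache classes — a standalone sketch of the round trip (tensor shapes here are arbitrary):

    import torch
    from transformers.cache_utils import DynamicCache

    # Build a two-layer cache with arbitrary (batch, heads, seq_len, head_dim) tensors.
    cache = DynamicCache()
    for layer_idx in range(2):
        cache.update(torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8), layer_idx)

    # Explicit conversion to the legacy tuple-of-tuples format...
    legacy = cache.to_legacy_cache()
    assert isinstance(legacy, tuple) and isinstance(legacy[0], tuple)

    # ...and back, recovering an equivalent `DynamicCache`.
    restored = DynamicCache.from_legacy_cache(legacy)
    assert all(torch.equal(a, b) for a, b in zip(cache.key_cache, restored.key_cache))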

View file

@@ -26,7 +26,7 @@
 import numpy as np
 import pytest
 from parameterized import parameterized
-from transformers import AutoConfig, is_torch_available, pipeline, set_seed
+from transformers import AutoConfig, is_torch_available, pipeline
 from transformers.testing_utils import (
     is_flaky,
     require_accelerate,
@@ -69,7 +69,7 @@ if is_torch_available():
         SpeechEncoderDecoderModel,
         T5ForConditionalGeneration,
     )
-    from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
+    from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
     from transformers.generation import (
         BeamSampleDecoderOnlyOutput,
         BeamSampleEncoderDecoderOutput,
@@ -1851,75 +1851,6 @@ class GenerationTesterMixin:
             )
         )

-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    @pytest.mark.generate
-    def test_new_cache_format(self, num_beams, do_sample):
-        # Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
-        # 👉 tests with and without beam search so that we can test with and without cache reordering.
-        # 👉 tests with and without sampling so we can cover the most common use cases.
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
-                self.skipTest(reason="This model does not support the new cache format")
-
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            model = model_class(config).to(torch_device).eval()
-            generation_kwargs = {
-                "max_new_tokens": 5,
-                "do_sample": do_sample,
-                "num_beams": num_beams,
-                "num_return_sequences": num_beams,
-                "return_dict_in_generate": True,  # Required to return `past_key_values`
-                "use_cache": True,
-            }
-
-            # Sets seed before calling `generate` for the case with do_sample=True
-            seed = torch.randint(0, 1000000, (1,)).item()
-            set_seed(seed)
-            legacy_results = model.generate(**generation_kwargs, **inputs_dict)
-            set_seed(seed)
-            if config.is_encoder_decoder:
-                cache_cls = EncoderDecoderCache
-                past_key_values = cache_cls(DynamicCache(), DynamicCache())
-            else:
-                cache_cls = DynamicCache
-                past_key_values = cache_cls()
-            new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
-
-            # The two sets of generated sequences must match, despite the cache format between forward passes being
-            # different
-            self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
-            self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
-            self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
-
-            # The contents of the two caches, when converted to the same format (in both directions!), must match
-            legacy_cache = legacy_results.past_key_values
-            new_cache_converted = new_results.past_key_values.to_legacy_cache()
-            for layer_idx in range(len(legacy_cache)):
-                for kv_idx in range(len(legacy_cache[layer_idx])):
-                    # TODO: @raushan, please look into this for new cache format
-                    if legacy_cache[layer_idx][kv_idx] != []:
-                        self.assertTrue(
-                            torch.allclose(
-                                legacy_cache[layer_idx][kv_idx],
-                                new_cache_converted[layer_idx][kv_idx],
-                            )
-                        )
-
-            new_cache = new_results.past_key_values
-            legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
-            for layer_idx in range(len(new_cache)):
-                for kv_idx in range(len(new_cache[layer_idx])):
-                    # TODO: @raushan, please look into this for new cache format
-                    if new_cache[layer_idx][kv_idx] != []:
-                        self.assertTrue(
-                            torch.allclose(
-                                new_cache[layer_idx][kv_idx],
-                                legacy_cache_converted[layer_idx][kv_idx],
-                            )
-                        )
-
     @parameterized.expand([("offloaded",)])  # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
     @require_torch_gpu
     @pytest.mark.generate
@@ -2438,11 +2369,11 @@
         )

     def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
-        self.assertIsInstance(past_key_values, tuple)
-        self.assertListEqual(
-            [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
-            [True] * len(past_key_values),
-        )
+        self.assertIsInstance(past_key_values, (tuple, Cache))
+
+        # Encoder-decoder models: pull and verify the decoder cache
+        if isinstance(past_key_values, EncoderDecoderCache):
+            past_key_values = past_key_values.self_attention_cache

         # (batch, head, seq_length, head_features)
         expected_shape = (
@@ -2451,15 +2382,32 @@
             seq_length,
             config.hidden_size // config.num_attention_heads,
         )
-        # check shape key, value
-        self.assertListEqual(
-            [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
-            [expected_shape] * len(past_key_values),
-        )
-        self.assertListEqual(
-            [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
-            [expected_shape] * len(past_key_values),
-        )
+
+        if isinstance(past_key_values, Cache):
+            self.assertListEqual(
+                [key_tensor.shape for key_tensor in past_key_values.key_cache],
+                [expected_shape] * len(past_key_values.key_cache),
+            )
+            self.assertListEqual(
+                [value_tensor.shape for value_tensor in past_key_values.value_cache],
+                [expected_shape] * len(past_key_values.value_cache),
+            )
+
+        # Legacy cache format checks. This branch should be removed when all models use `Cache` by default
+        else:
+            self.assertListEqual(
+                [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
+                [True] * len(past_key_values),
+            )
+            # check shape key, value
+            self.assertListEqual(
+                [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
+                [expected_shape] * len(past_key_values),
+            )
+            self.assertListEqual(
+                [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
+                [expected_shape] * len(past_key_values),
+            )

     def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
         # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
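
The reworked checker branches on the cache type: `Cache` instances expose per-layer tensors through `key_cache`/`value_cache`, so shapes can be read off directly rather than by tuple indexing. A self-contained sketch of the `Cache` branch, with made-up dimensions:

    import torch
    from transformers.cache_utils import DynamicCache

    batch, heads, seq_len, head_dim = 2, 4, 5, 8
    expected_shape = (batch, heads, seq_len, head_dim)

    cache = DynamicCache()
    for layer_idx in range(3):
        cache.update(torch.randn(*expected_shape), torch.randn(*expected_shape), layer_idx)

    # Mirrors the new `Cache` branch of `_check_past_key_values_for_generate`.
    assert [k.shape for k in cache.key_cache] == [expected_shape] * len(cache.key_cache)
    assert [v.shape for v in cache.value_cache] == [expected_shape] * len(cache.value_cache)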

View file

@@ -268,18 +268,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
     def test_sdpa_can_dispatch_on_flash(self):
         pass

-    @unittest.skip(reason="")
-    def test_new_cache_format_0(self):
-        pass
-
-    @unittest.skip(reason="")
-    def test_new_cache_format_1(self):
-        pass
-
-    @unittest.skip(reason="")
-    def test_new_cache_format_2(self):
-        pass
-
     @unittest.skip(reason="Feedforward chunking is not yet supported")
     def test_feed_forward_chunking(self):
         pass

View file

@@ -18,7 +18,6 @@
 import inspect
 import unittest

 import pytest
-from parameterized import parameterized
 from transformers import AutoTokenizer, BambaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -395,11 +394,6 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
             [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
         )

-    @unittest.skip(reason="Bamba has its own special cache type")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def test_batching_equivalence(self):
         # need to disable the tril input mask
         orig = self.model_tester.use_input_mask

View file

@@ -103,11 +103,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
     def test_dola_decoding_sample(self):
         pass

-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    @unittest.skip("Cohere2 has HybridCache and doesn't support old tuple format at all")
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     @unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
     def test_generate_continue_from_past_key_values(self):
         pass

View file

@@ -117,11 +117,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
     def test_dola_decoding_sample(self):
         pass

-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    @unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     @unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
     def test_generate_continue_from_past_key_values(self):
         pass
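
These skips could be dropped wholesale because the legacy-equivalence test no longer exists; the underlying constraint — `HybridCache` has no tuple representation — is what the `hasattr(..., "to_legacy_cache")` guard in `generate` accounts for. A quick class-level check, assuming (as that guard implies) that `HybridCache` does not implement the conversion:

    from transformers.cache_utils import DynamicCache, HybridCache

    # DynamicCache can express itself in the legacy tuple-of-tuples format...
    print(hasattr(DynamicCache, "to_legacy_cache"))  # True
    # ...while HybridCache (used by e.g. Gemma2 / Cohere2) is assumed to have no
    # legacy equivalent, which is why the conversion in `generate` is guarded.
    print(hasattr(HybridCache, "to_legacy_cache"))   # expected: False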

View file

@@ -19,7 +19,6 @@
 import tempfile
 import unittest

 import pytest
-from parameterized import parameterized
 from transformers import AutoTokenizer, JambaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -550,11 +549,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         """
         self.skipTest(reason="Jamba flash attention does not support right padding")

-    @unittest.skip(reason="Jamba has its own special cache type")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-

 @require_torch
 class JambaModelIntegrationTest(unittest.TestCase):

View file

@@ -18,7 +18,6 @@
 import gc
 import unittest

 import pytest
-from parameterized import parameterized
 from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
 from transformers.testing_utils import (
@@ -299,10 +298,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     test_disk_offload_bin = False
     test_disk_offload_safetensors = False

-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def setUp(self):
         self.model_tester = JetMoeModelTester(self)
         self.config_tester = ConfigTester(

View file

@@ -19,7 +19,6 @@
 import tempfile
 import unittest

 import pytest
-from parameterized import parameterized
 from transformers import AutoTokenizer, ZambaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -551,11 +550,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         """
         self.skipTest(reason="Zamba flash attention does not support right padding")

-    @unittest.skip(reason="Zamba has its own special cache type")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-

 @require_torch
 class ZambaModelIntegrationTest(unittest.TestCase):