mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[generate] return Cache object even if passed in a legacy format (#35673)
* generate returns a Cache object by default * fix tests * fix test for encoder-decoder models
This commit is contained in:
parent
2818307e93
commit
94af1c0aa2
9 changed files with 36 additions and 156 deletions
|
|
@ -2111,9 +2111,6 @@ class GenerationMixin:
|
|||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
user_defined_cache = model_kwargs.get(cache_name)
|
||||
max_cache_length = generation_config.max_length
|
||||
if (
|
||||
inputs_tensor.shape[1] != input_ids_length
|
||||
|
|
@ -2395,32 +2392,12 @@ class GenerationMixin:
|
|||
|
||||
# Convert to legacy cache format if requested
|
||||
if (
|
||||
generation_config.return_legacy_cache is not False # Should check for `True` after v4.47
|
||||
generation_config.return_legacy_cache is True
|
||||
and not is_torchdynamo_compiling()
|
||||
and hasattr(result, "past_key_values")
|
||||
and hasattr(result.past_key_values, "to_legacy_cache")
|
||||
and result.past_key_values.to_legacy_cache is not None
|
||||
and getattr(result.past_key_values, "to_legacy_cache") is not None
|
||||
):
|
||||
# handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type)
|
||||
should_convert_cache = generation_config.return_legacy_cache
|
||||
is_user_defined_cache = user_defined_cache is not None
|
||||
is_default_cache_type = (
|
||||
type(result.past_key_values) == DynamicCache # noqa E721
|
||||
or (
|
||||
isinstance(result.past_key_values, EncoderDecoderCache)
|
||||
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
|
||||
and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721
|
||||
)
|
||||
)
|
||||
if not is_user_defined_cache and is_default_cache_type:
|
||||
logger.warning_once(
|
||||
"From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
|
||||
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
|
||||
"keep returning the legacy format, please set `return_legacy_cache=True`."
|
||||
)
|
||||
should_convert_cache = True
|
||||
if should_convert_cache:
|
||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||
return result
|
||||
|
||||
def _has_unfinished_sequences(
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import numpy as np
|
|||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
||||
from transformers import AutoConfig, is_torch_available, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
|
|
@ -69,7 +69,7 @@ if is_torch_available():
|
|||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
|
|
@ -1851,75 +1851,6 @@ class GenerationTesterMixin:
|
|||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
@pytest.mark.generate
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
|
||||
# 👉 tests with and without beam search so that we can test with and without cache reordering.
|
||||
# 👉 tests with and without sampling so we can cover the most common use cases.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"do_sample": do_sample,
|
||||
"num_beams": num_beams,
|
||||
"num_return_sequences": num_beams,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
# Sets seed before calling `generate` for the case with do_sample=True
|
||||
seed = torch.randint(0, 1000000, (1,)).item()
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
set_seed(seed)
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||
else:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
|
||||
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
# different
|
||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
|
||||
|
||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
||||
legacy_cache = legacy_results.past_key_values
|
||||
new_cache_converted = new_results.past_key_values.to_legacy_cache()
|
||||
for layer_idx in range(len(legacy_cache)):
|
||||
for kv_idx in range(len(legacy_cache[layer_idx])):
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
if legacy_cache[layer_idx][kv_idx] != []:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
legacy_cache[layer_idx][kv_idx],
|
||||
new_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
new_cache = new_results.past_key_values
|
||||
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
||||
for layer_idx in range(len(new_cache)):
|
||||
for kv_idx in range(len(new_cache[layer_idx])):
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
if new_cache[layer_idx][kv_idx] != []:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][kv_idx],
|
||||
legacy_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
|
|
@ -2438,11 +2369,11 @@ class GenerationTesterMixin:
|
|||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||
self.assertIsInstance(past_key_values, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||
[True] * len(past_key_values),
|
||||
)
|
||||
self.assertIsInstance(past_key_values, (tuple, Cache))
|
||||
|
||||
# Encoder-decoder models: pull and verify the decoder cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
past_key_values = past_key_values.self_attention_cache
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_shape = (
|
||||
|
|
@ -2451,15 +2382,32 @@ class GenerationTesterMixin:
|
|||
seq_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
|
||||
if isinstance(past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in past_key_values.key_cache],
|
||||
[expected_shape] * len(past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in past_key_values.value_cache],
|
||||
[expected_shape] * len(past_key_values.value_cache),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
else:
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||
[True] * len(past_key_values),
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||
|
|
|
|||
|
|
@ -268,18 +268,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
|||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="")
|
||||
def test_new_cache_format_0(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="")
|
||||
def test_new_cache_format_1(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="")
|
||||
def test_new_cache_format_2(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import inspect
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
|
|
@ -395,11 +394,6 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Bamba has its own special cache type")
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
def test_batching_equivalence(self):
|
||||
# need to disable the tril input mask
|
||||
orig = self.model_tester.use_input_mask
|
||||
|
|
|
|||
|
|
@ -103,11 +103,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support old tuple format at all")
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -117,11 +117,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import tempfile
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
|
|
@ -550,11 +549,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
"""
|
||||
self.skipTest(reason="Jamba flash attention does not support right padding")
|
||||
|
||||
@unittest.skip(reason="Jamba has its own special cache type")
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class JambaModelIntegrationTest(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import gc
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
|
|
@ -299,10 +298,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
test_disk_offload_bin = False
|
||||
test_disk_offload_safetensors = False
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = JetMoeModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import tempfile
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, ZambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
|
|
@ -551,11 +550,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
"""
|
||||
self.skipTest(reason="Zamba flash attention does not support right padding")
|
||||
|
||||
@unittest.skip(reason="Zamba has its own special cache type")
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class ZambaModelIntegrationTest(unittest.TestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue