Fix low memory beam search (#34746)

* fix

* higher max positions in tests
This commit is contained in:
Raushan Turganbay 2024-11-20 07:46:35 +01:00 committed by GitHub
parent 145fbd46cb
commit 9470d65324
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 10 additions and 5 deletions

View file

@ -528,7 +528,7 @@ class DynamicCache(Cache):
cache = cls()
for idx in range(len(splits[0])):
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
if key_cache != []:
layer_keys = torch.cat(key_cache, dim=0)
layer_values = torch.cat(value_cache, dim=0)
@ -1523,7 +1523,10 @@ class EncoderDecoderCache(Cache):
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
@ -1536,7 +1539,10 @@ class EncoderDecoderCache(Cache):
return out
@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def from_batch_splits(
cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()

View file

@ -1046,7 +1046,6 @@ class GenerationTesterMixin:
self.assertListEqual(low_output.tolist(), high_output.tolist())
@pytest.mark.generate
@unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:

View file

@ -330,7 +330,7 @@ class Blip2TextModelDecoderOnlyTester:
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
max_position_embeddings=512,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,