Mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix low memory beam search (#34746)
* fix
* higher max positions in tests
This commit is contained in:
parent
145fbd46cb
commit
9470d65324
3 changed files with 10 additions and 5 deletions
|
|
@ -528,7 +528,7 @@ class DynamicCache(Cache):
|
|||
cache = cls()
|
||||
for idx in range(len(splits[0])):
|
||||
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
||||
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
||||
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
|
||||
if key_cache != []:
|
||||
layer_keys = torch.cat(key_cache, dim=0)
|
||||
layer_values = torch.cat(value_cache, dim=0)
|
||||
|
|
@ -1523,7 +1523,10 @@ class EncoderDecoderCache(Cache):
|
|||
self.check_dynamic_cache(self.crop.__name__)
|
||||
self.self_attention_cache.crop(maximum_length)
|
||||
|
||||
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
|
||||
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
|
||||
def batch_split(
|
||||
self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
|
||||
) -> "List[EncoderDecoderCache]":
|
||||
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
|
||||
`_split_model_inputs()` in `generation.utils`"""
|
||||
self.check_dynamic_cache(self.batch_split.__name__)
|
||||
|
|
@ -1536,7 +1539,10 @@ class EncoderDecoderCache(Cache):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
|
||||
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
|
||||
def from_batch_splits(
|
||||
cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
|
||||
) -> "EncoderDecoderCache":
|
||||
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
|
||||
`generation.utils`"""
|
||||
self_attention_cache = DynamicCache()
|
||||
|
|
|
|||
|
|
@ -1046,7 +1046,6 @@ class GenerationTesterMixin:
|
|||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
|
|
|||
|
|
@ -330,7 +330,7 @@ class Blip2TextModelDecoderOnlyTester:
|
|||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_position_embeddings=512,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
|
|
|
|||
Loading…
Reference in a new issue