diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 822692315..ad497581c 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1142,7 +1142,7 @@ class StaticCache(Cache):
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
-        cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
             if layer_device_map is not None:
                 layer_device = layer_device_map[idx]