Remove batch size argument warning when unjustified (#35519)

* use max batch size

* revert unnecessary change

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
This commit is contained in:
Quinten Roets 2025-01-16 08:48:11 -08:00 committed by GitHub
parent 91be6a5eb2
commit 57bf1a12a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1142,7 +1142,7 @@ class StaticCache(Cache):
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
- cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+ cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[idx]