diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index b2be3f238..911966323 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -9,12 +9,7 @@ import torch
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg
 
 
@@ -24,7 +19,7 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
 
 
-class Cache(torch.nn.Module):
+class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
@@ -1144,18 +1139,10 @@ class StaticCache(Cache):
                 layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            #    it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index a0cbc8ba4..4ee525ddf 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
 
 
 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 
 
@@ -72,9 +69,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-            dtype=self.model.dtype,
             device=self.model.generation_config.cache_config.device,
+            dtype=self.model.dtype,
         )
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+
         self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
         if self.is_causal:
             causal_mask = torch.tril(
@@ -109,12 +110,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         """
        _, seqlen = input_ids.shape
         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
+
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
+            past_key_values=past_key_values,
             use_cache=True,
         )
         return outs.logits
@@ -143,7 +147,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         prompt_token_len = prompt_token_ids.shape[-1]
         max_generation_length = prompt_token_len + max_new_tokens
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 max_cache_len = buffer.shape[2]
                 max_generation_length = min(max_generation_length, max_cache_len)
                 break
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index d67b02663..a8b8b1eff 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
         # Check if the exported model is configured with the `StaticCache` correctly
         n_static_key_caches = n_static_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_value_caches = n_static_value_caches + 1
@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
             input_ids = gen_out
 
         # We went well beyond the cache length
-        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)
 
         # And it still produces a coherent english
         decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
             "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
             'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
         ]  # fmt: skip
-        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+        self.assertEqual(responses, EXPECTED_DECODED_TEXT)
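
Note for reviewers: a minimal sketch of how the relocated cache buffers can be inspected after export. The helper name `count_static_cache_buffers` is hypothetical and not part of this PR; the prefix checks and the `shape[2]` lookup mirror the updated assertions in `test_cache_utils.py` and the `max_cache_len` probe in `TorchExportableModuleWithStaticCache.generate`.

# Hypothetical helper, not part of the PR: summarize the StaticCache buffers of an
# exported program. Because TorchExportableModuleWithStaticCache now registers the
# cache tensors on itself, they surface as top-level "key_cache_{i}" /
# "value_cache_{i}" buffers instead of "static_cache.key_cache_{i}".
def count_static_cache_buffers(exported_program):
    n_key = n_value = 0
    max_cache_len = None
    for buffer_name, buffer in exported_program.named_buffers():
        if buffer_name.startswith("key_cache"):
            n_key += 1
            # StaticCache layout: (batch_size, num_key_value_heads, max_cache_len, head_dim)
            max_cache_len = buffer.shape[2]
        elif buffer_name.startswith("value_cache"):
            n_value += 1
    return n_key, n_value, max_cache_len

Here `exported_program` is assumed to be the result of `torch.export.export(...)` on a model wrapped in `TorchExportableModuleWithStaticCache`, as exercised by the updated test.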