diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 6acf0c256..eead49b33 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -27,8 +27,7 @@ class Cache(torch.Tensor):
 
     @staticmethod
     def __new__(cls, *args, **kwargs):
-        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module
-        # dtype and device don't need to be in the subclass's __init__ (unless they are used for something)
+        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
         wrapper_kwargs = {}
         init_signature = inspect.signature(cls.__init__)
 
@@ -54,19 +53,38 @@ class Cache(torch.Tensor):
             ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
             return getattr(cls, func.__name__)(*args, **kwargs)
 
-    def __bool__(self):
-        # in many places, past_key_values is checked for not being None using `if past_key_values:`
-        # I think `if past_key_values is not None:` should be used instead
-        return True
-
     def __repr__(self):
         return f"{self.__class__.__name__}()"
 
+    def __bool__(self):
+        # in many places, past_key_values is checked for not being None using `if past_key_values:`
+        # I think `if past_key_values is not None:` should be used instead
+        return self is not None  # True
+
     def to(self, *args, **kwargs):
-        # We override this method to prevent the cache from being moved to a different device
-        # It can be implemented in a way that moves all contained tensors to the new device/dtype
+        def recursive_to(elm):
+            if isinstance(elm, dict):
+                return {k: recursive_to(v) for k, v in elm.items()}
+            elif isinstance(elm, (list, tuple, set)):
+                return type(elm)(recursive_to(t) for t in elm)
+            elif isinstance(elm, torch.Tensor):
+                return elm.to(*args, **kwargs)
+            else:
+                return elm
+
+        self.__dict__ = recursive_to(self.__dict__)
         return self
 
+    def clone(self):
+        wrapper_kwargs = {
+            "dtype": getattr(self, "dtype", None),
+            "device": getattr(self, "device", None),
+            "requires_grad": getattr(self, "requires_grad", None),
+        }
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -267,7 +285,6 @@ class QuantizedCacheConfig(CacheConfig):
         q_group_size: Optional[int] = 64,
         residual_length: Optional[int] = 128,
         compute_dtype: Optional[torch.dtype] = torch.float16,
-        device: Optional[str] = "cpu",
     ):
         self.backend = backend
         self.nbits = nbits
@@ -276,7 +293,6 @@ class QuantizedCacheConfig(CacheConfig):
         self.q_group_size = q_group_size
         self.residual_length = residual_length
         self.compute_dtype = compute_dtype
-        self.device = device
 
     def validate(self):
         """Validates if the arguments passed are correct"""
@@ -339,10 +355,9 @@ class StaticCacheConfig(CacheConfig):
 
     cache_implementation = "static"
 
-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int):
        self.batch_size = batch_size
         self.max_cache_len = max_cache_len
-        self.device = device
 
     def validate(self):
         """Validates if the arguments passed are correct"""
@@ -710,7 +725,7 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
+        self.to(cache_config.device)
 
     def update(
         self,
@@ -1955,7 +1970,8 @@ class OffloadedStaticCache(StaticCache):
     ) -> None:
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
+        if layer_device_map is not None:
+            self.to(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
 
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
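
Below is a minimal, self-contained sketch (not the transformers implementation) of the tensor-wrapper pattern this diff relies on: the cache subclasses torch.Tensor through torch.Tensor._make_wrapper_subclass so it can be passed around and traced like a tensor, while the real payload lives as ordinary tensors in __dict__; to() then moves that payload recursively and clone() deep-copies it. The ToyCache class, its key_cache/value_cache attributes, and the demo values are hypothetical and only illustrate the mechanism.

# Hypothetical ToyCache: a sketch of the wrapper pattern used by Cache above, not part of this patch.
import copy

import torch


class ToyCache(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        # Zero-sized wrapper tensor: the subclass itself holds no data,
        # it only exists so the cache can travel through tensor-typed inputs.
        return torch.Tensor._make_wrapper_subclass(cls, ())

    def __init__(self):
        # The actual payload lives in ordinary attributes, i.e. in __dict__.
        self.key_cache = []
        self.value_cache = []

    def to(self, *args, **kwargs):
        # Recursively move/convert every tensor stored on the wrapper,
        # mirroring the strategy of Cache.to() in the diff.
        def recursive_to(elm):
            if isinstance(elm, dict):
                return {k: recursive_to(v) for k, v in elm.items()}
            if isinstance(elm, (list, tuple, set)):
                return type(elm)(recursive_to(t) for t in elm)
            if isinstance(elm, torch.Tensor):
                return elm.to(*args, **kwargs)
            return elm

        self.__dict__ = recursive_to(self.__dict__)
        return self

    def clone(self):
        # A fresh wrapper whose contents are an independent deep copy.
        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, ())
        new_self.__dict__ = copy.deepcopy(self.__dict__)
        return new_self


cache = ToyCache()
cache.key_cache.append(torch.zeros(2, 4, 8, 16))
cache.value_cache.append(torch.zeros(2, 4, 8, 16))

cache.to(torch.float16)        # dtype conversion; cache.to("cuda") would move devices the same way
snapshot = cache.clone()       # independent copy, mutating one does not touch the other

print(cache.key_cache[0].dtype)  # torch.float16
print(snapshot is cache)         # False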