add clone() and to() methods

IlyasMoutawwakil 2025-01-22 15:42:43 +01:00
parent 4950a9e3f0
commit 6e9799c817


@@ -27,8 +27,7 @@ class Cache(torch.Tensor):
     @staticmethod
     def __new__(cls, *args, **kwargs):
-        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module
-        # dtype and device don't need to be in the subclass's __init__ (unless they are used for something)
+        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
         wrapper_kwargs = {}
         init_signature = inspect.signature(cls.__init__)
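
For reference, a stripped-down sketch (not part of the diff) of the tensor-wrapper pattern the comment refers to. torch.Tensor._make_wrapper_subclass is a private PyTorch API, and the real __new__ additionally inspects the subclass's __init__ signature to pick up dtype/device:

import torch

class MinimalWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        # a zero-sized wrapper tensor; the actual cache data lives in ordinary attributes
        return torch.Tensor._make_wrapper_subclass(cls, ())

wrapper = MinimalWrapper()
print(isinstance(wrapper, torch.Tensor))  # True
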
@@ -54,19 +53,38 @@ class Cache(torch.Tensor):
         ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
         return getattr(cls, func.__name__)(*args, **kwargs)
-    def __bool__(self):
-        # in many places, past_key_values is checked for not being None using `if past_key_values:`
-        # I think `if past_key_values is not None:` should be used instead
-        return True
     def __repr__(self):
         return f"{self.__class__.__name__}()"
+    def __bool__(self):
+        # in many places, past_key_values is checked for not being None using `if past_key_values:`
+        # I think `if past_key_values is not None:` should be used instead
+        return self is not None # True
+    def to(self, *args, **kwargs):
+        # We override this method to prevent the cache from being moved to a different device
+        # It can be implemented in a way that moves all contained tensors to the new device/dtype
+        def reccursive_to(elm):
+            if isinstance(elm, dict):
+                return {k: reccursive_to(v) for k, v in elm.items()}
+            elif isinstance(elm, (list, tuple, set)):
+                return type(elm)(reccursive_to(t) for t in elm)
+            elif isinstance(elm, torch.Tensor):
+                return elm.to(*args, **kwargs)
+            else:
+                return elm
+        self.__dict__ = reccursive_to(self.__dict__)
+        return self
+    def clone(self):
+        wrapper_kwargs = {
+            "dtype": getattr(self, "dtype", None),
+            "device": getattr(self, "device", None),
+            "requires_grad": getattr(self, "requires_grad", None),
+        }
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self
     def update(
         self,
         key_states: torch.Tensor,
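
A minimal usage sketch of the two new methods (not part of the diff), assuming they land on transformers' DynamicCache and using its usual update(key_states, value_states, layer_idx) signature with illustrative shapes:

import torch
from transformers.cache_utils import DynamicCache

key_states = torch.randn(1, 8, 4, 64)    # (batch, num_heads, seq_len, head_dim), illustrative
value_states = torch.randn(1, 8, 4, 64)

cache = DynamicCache()
cache.update(key_states, value_states, layer_idx=0)

cache.to(torch.float16)     # recursively casts the tensors held in cache.__dict__, returns self
cache_copy = cache.clone()  # independent wrapper holding a deep copy of the cache's state
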
@@ -267,7 +285,6 @@ class QuantizedCacheConfig(CacheConfig):
         q_group_size: Optional[int] = 64,
         residual_length: Optional[int] = 128,
         compute_dtype: Optional[torch.dtype] = torch.float16,
-        device: Optional[str] = "cpu",
     ):
         self.backend = backend
         self.nbits = nbits
@@ -276,7 +293,6 @@ class QuantizedCacheConfig(CacheConfig):
         self.q_group_size = q_group_size
         self.residual_length = residual_length
         self.compute_dtype = compute_dtype
-        self.device = device
     def validate(self):
         """Validates if the arguments passed are correct"""
@@ -339,10 +355,9 @@ class StaticCacheConfig(CacheConfig):
     cache_implementation = "static"
-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int):
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
-        self.device = device
     def validate(self):
         """Validates if the arguments passed are correct"""
@@ -710,7 +725,7 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
+        self.to(cache_config.device)
     def update(
         self,
@@ -1955,7 +1970,8 @@ class OffloadedStaticCache(StaticCache):
     ) -> None:
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
+        if layer_device_map is not None:
+            self.to(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
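
The new to() relies on its recursive helper to reach tensors nested inside dicts, lists, tuples and sets (for example per-layer key/value lists). A standalone sketch of the same traversal, with illustrative names:

import torch

def recursive_to(obj, *args, **kwargs):
    # rebuild containers, move/cast tensor leaves, leave everything else untouched
    if isinstance(obj, dict):
        return {k: recursive_to(v, *args, **kwargs) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return type(obj)(recursive_to(t, *args, **kwargs) for t in obj)
    if isinstance(obj, torch.Tensor):
        return obj.to(*args, **kwargs)
    return obj

state = {"key_cache": [torch.zeros(2, 3)], "value_cache": [torch.zeros(2, 3)]}
state = recursive_to(state, torch.float16)  # both nested tensors are now float16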