mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add clone and to
This commit is contained in:
parent
4950a9e3f0
commit
6e9799c817
1 changed files with 31 additions and 15 deletions
|
|
@ -27,8 +27,7 @@ class Cache(torch.Tensor):
|
|||
|
||||
@staticmethod
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module
|
||||
# dtype and device don't need to be in the subclass's __init__ (unless they are used for something)
|
||||
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
|
||||
|
||||
wrapper_kwargs = {}
|
||||
init_signature = inspect.signature(cls.__init__)
|
||||
|
|
@ -54,19 +53,38 @@ class Cache(torch.Tensor):
|
|||
), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
|
||||
return getattr(cls, func.__name__)(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
    """Return a short representation of the form `<ClassName>()`."""
    return f"{self.__class__.__name__}()"


def __bool__(self):
    """Always report the cache as truthy.

    In many places `past_key_values` is checked for presence with
    `if past_key_values:`; returning True keeps those checks working.
    Such call sites would arguably be better written as
    `if past_key_values is not None:`.

    Returns:
        bool: always True.
    """
    # `self is not None` can never be False inside a bound method, so the
    # constant is returned directly. This is the only `__bool__` definition:
    # an earlier duplicate was shadowed and therefore dead.
    return True
|
||||
|
||||
def to(self, *args, **kwargs):
    """Move every tensor stored on this object to a new device and/or dtype.

    The instance attributes (`self.__dict__`) are walked recursively and
    `.to(*args, **kwargs)` is applied to each `torch.Tensor` found inside
    dicts, lists, tuples (including namedtuples) and sets. Values of any
    other type are kept unchanged. Containers are rebuilt, not mutated.

    NOTE: only the contained attributes are converted; this method does not
    touch the wrapper tensor's own dtype/device metadata.

    Args:
        *args: positional arguments forwarded to `torch.Tensor.to`.
        **kwargs: keyword arguments forwarded to `torch.Tensor.to`.

    Returns:
        `self`, with its attributes replaced by the converted copies.
    """

    def recursive_to(elm):
        if isinstance(elm, dict):
            return {k: recursive_to(v) for k, v in elm.items()}
        if isinstance(elm, tuple) and hasattr(elm, "_fields"):
            # namedtuples take their fields positionally; handing them a
            # single generator argument raises TypeError.
            return type(elm)(*(recursive_to(t) for t in elm))
        if isinstance(elm, (list, tuple, set)):
            return type(elm)(recursive_to(t) for t in elm)
        if isinstance(elm, torch.Tensor):
            return elm.to(*args, **kwargs)
        return elm

    self.__dict__ = recursive_to(self.__dict__)
    return self
|
||||
|
||||
def clone(self):
|
||||
wrapper_kwargs = {
|
||||
"dtype": getattr(self, "dtype", None),
|
||||
"device": getattr(self, "device", None),
|
||||
"requires_grad": getattr(self, "requires_grad", None),
|
||||
}
|
||||
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
|
||||
new_self.__dict__ = copy.deepcopy(self.__dict__)
|
||||
return new_self
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
|
|
@ -267,7 +285,6 @@ class QuantizedCacheConfig(CacheConfig):
|
|||
q_group_size: Optional[int] = 64,
|
||||
residual_length: Optional[int] = 128,
|
||||
compute_dtype: Optional[torch.dtype] = torch.float16,
|
||||
device: Optional[str] = "cpu",
|
||||
):
|
||||
self.backend = backend
|
||||
self.nbits = nbits
|
||||
|
|
@ -276,7 +293,6 @@ class QuantizedCacheConfig(CacheConfig):
|
|||
self.q_group_size = q_group_size
|
||||
self.residual_length = residual_length
|
||||
self.compute_dtype = compute_dtype
|
||||
self.device = device
|
||||
|
||||
def validate(self):
|
||||
"""Validates if the arguments passed are correct"""
|
||||
|
|
@ -339,10 +355,9 @@ class StaticCacheConfig(CacheConfig):
|
|||
|
||||
cache_implementation = "static"
|
||||
|
||||
def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
|
||||
def __init__(self, batch_size: int, max_cache_len: int):
|
||||
self.batch_size = batch_size
|
||||
self.max_cache_len = max_cache_len
|
||||
self.device = device
|
||||
|
||||
def validate(self):
|
||||
"""Validates if the arguments passed are correct"""
|
||||
|
|
@ -710,7 +725,7 @@ class QuantizedCache(DynamicCache):
|
|||
self.axis_key = cache_config.axis_key
|
||||
self.axis_value = cache_config.axis_value
|
||||
self.compute_dtype = cache_config.compute_dtype
|
||||
self.device = cache_config.device
|
||||
self.to(cache_config.device)
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
|
@ -1955,7 +1970,8 @@ class OffloadedStaticCache(StaticCache):
|
|||
) -> None:
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
|
||||
if layer_device_map is not None:
|
||||
self.to(layer_device_map[0])
|
||||
self.offload_device = torch.device(offload_device)
|
||||
|
||||
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
|
|
|
|||
Loading…
Reference in a new issue