fixed dynamic cache

IlyasMoutawwakil 2025-01-23 16:45:28 +01:00
parent 80b49d721b
commit 5ccb79c16d


@@ -36,19 +36,28 @@ class Cache(torch.Tensor):
             k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
         }
-        for argument in ["dtype", "device", "requires_grad"]:
+        for argument in ["dtype", "device"]:
             if argument in init_arguments:
-                argument_index = init_arguments.index(argument)
-                if len(args) > argument_index:
-                    wrapper_kwargs[argument] = args[argument_index]
-                elif argument in kwargs:
+                arg_idx = init_arguments.index(argument)
+                if len(args) > arg_idx and args[arg_idx] is not None:
+                    wrapper_kwargs[argument] = args[arg_idx]
+                elif kwargs.get(argument, None) is not None:
                     wrapper_kwargs[argument] = kwargs[argument]
-                elif argument in init_defaults:
+                elif init_defaults[argument] is not None:
                     wrapper_kwargs[argument] = init_defaults[argument]
-        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs)
-        cls.__init__(self, *args, **kwargs)
+        if "cache_config" in init_arguments:
+            cache_config_idx = init_arguments.index("cache_config")
+            if len(args) > cache_config_idx and args[cache_config_idx] is not None:
+                wrapper_kwargs["device"] = args[cache_config_idx].device
+            elif kwargs.get("cache_config", None) is not None:
+                wrapper_kwargs["device"] = kwargs["cache_config"].device
+            elif init_defaults["cache_config"] is not None:
+                wrapper_kwargs["device"] = init_defaults["cache_config"].device
+        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
+        # we create a dummy empty tensor for generic tensor flattening/unflattening
+        self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
         return self

     @classmethod
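Not part of the diff: a minimal, self-contained sketch of the signature-inspection pattern the reworked `__new__` relies on. `TensorBackedCache` and `MyCache` are hypothetical names for illustration; the real class additionally routes `cache_config.device` and creates `_empty_tensor`, as shown above.

```python
import inspect

import torch


class TensorBackedCache(torch.Tensor):  # hypothetical stand-in for `Cache`
    def __new__(cls, *args, **kwargs):
        # Resolve dtype/device for the wrapper tensor from the subclass __init__:
        # positional args first, then kwargs, then declared defaults.
        params = list(inspect.signature(cls.__init__).parameters.values())[1:]  # drop `self`
        names = [p.name for p in params]
        defaults = {p.name: p.default for p in params if p.default is not inspect.Parameter.empty}
        wrapper_kwargs = {}
        for argument in ["dtype", "device"]:
            if argument in names:
                idx = names.index(argument)
                if len(args) > idx and args[idx] is not None:
                    wrapper_kwargs[argument] = args[idx]
                elif kwargs.get(argument) is not None:
                    wrapper_kwargs[argument] = kwargs[argument]
                elif defaults.get(argument) is not None:
                    wrapper_kwargs[argument] = defaults[argument]
        # Zero-sized wrapper: the cache itself holds no storage.
        return torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)


class MyCache(TensorBackedCache):  # hypothetical subclass
    def __init__(self, dtype=torch.float16, device=None):
        self.key_cache = []


cache = MyCache(device=torch.device("cpu"))
print(cache.dtype, cache.device)  # torch.float16 cpu
```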
@@ -66,42 +75,27 @@ class Cache(torch.Tensor):
         # I think `if past_key_values is not None:` should be used instead
         return self is not None  # True

-    def __setattr__(self, name, value):
-        # for the many places where `self.device` or `self.dtype` is set
-        if name in ["dtype", "device"]:
-            self.to(value)
-        else:
-            return super().__setattr__(name, value)
-
     def to(self, *args, **kwargs):
         # originals
-        wrapper_kwargs = {
-            "dtype": getattr(self, "dtype", None),
-            "device": getattr(self, "device", None),
-            "requires_grad": getattr(self, "requires_grad", False),
-        }
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
         # overrides
         for arg in list(args) + list(kwargs.values()):
-            if isinstance(arg, torch.device):
+            if isinstance(arg, (torch.device, str, int)):
                 wrapper_kwargs["device"] = arg
             elif isinstance(arg, torch.dtype):
                 wrapper_kwargs["dtype"] = arg
-        new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
-        new_tensor_wrapper.__dict__ = self.__dict__
-        self = new_tensor_wrapper
-        return self
+        # new wrapper
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
+        return new_self

     def clone(self):
-        wrapper_kwargs = {
-            "dtype": getattr(self, "dtype", None),
-            "device": getattr(self, "device", None),
-            "requires_grad": getattr(self, "requires_grad", False),
-        }
-        new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
-        new_tensor_wrapper.__dict__ = copy.deepcopy(self.__dict__)
-        return new_tensor_wrapper
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self

     def update(
         self,
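Not part of the diff: how the reworked `to()` behaves in practice. Because a wrapper tensor's dtype/device are fixed at construction, `to()` rebuilds the zero-sized wrapper and carries the instance `__dict__` over, minus any stale `device`/`dtype` entries that would shadow the new wrapper's tensor properties. A sketch, reusing the hypothetical `TensorBackedCache` from the previous example:

```python
class MovableCache(TensorBackedCache):  # hypothetical subclass
    def __init__(self, dtype=torch.float32, device=torch.device("cpu")):
        self.key_cache = []

    def to(self, *args, **kwargs):
        wrapper_kwargs = {"dtype": self.dtype, "device": self.device}
        for arg in list(args) + list(kwargs.values()):
            if isinstance(arg, (torch.device, str, int)):
                wrapper_kwargs["device"] = arg
            elif isinstance(arg, torch.dtype):
                wrapper_kwargs["dtype"] = arg
        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
        # Shallow-copy instance state; drop entries shadowing tensor properties.
        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
        return new_self


c1 = MovableCache()
c2 = c1.to(torch.float16)
print(c1.dtype, c2.dtype)            # torch.float32 torch.float16
print(c2.key_cache is c1.key_cache)  # True: state is shared, only the wrapper changed
```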
@@ -375,7 +369,7 @@ class StaticCacheConfig(CacheConfig):

     cache_implementation = "static"

-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
         self.device = device
@@ -432,6 +426,16 @@ class DynamicCache(Cache):
         ```
     """

+    def __tensor_flatten__(self):
+        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, _, __):
+        cache = DynamicCache()
+        cache._seen_tokens = meta["_seen_tokens"]
+        cache._empty_tensor = inner_tensors["_empty_tensor"]
+        return cache
+
     @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
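Not part of the diff: the two hooks above implement the tensor-subclass flatten/unflatten contract, the protocol torch.compile and pytree utilities use to decompose a subclass into plain inner tensors plus static metadata and rebuild it. A hand-rolled round trip, assuming `DynamicCache()` is default-constructible as its `__init__` above suggests:

```python
cache = DynamicCache()
cache._seen_tokens = 7

attrs, meta = cache.__tensor_flatten__()  # (["_empty_tensor"], {"_seen_tokens": 7})
inner_tensors = {name: getattr(cache, name) for name in attrs}
rebuilt = DynamicCache.__tensor_unflatten__(inner_tensors, meta, None, None)
assert rebuilt._seen_tokens == 7
```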
@@ -519,7 +523,7 @@ class DynamicCache(Cache):
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
         )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
         return layer_seq_length

     def get_max_cache_shape(self) -> Optional[int]:
@@ -746,9 +750,6 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
-
-        super().__init__()

     def update(
         self,
@@ -848,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
                 raise ImportError(
                     f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
                 )
-            from optimum.quanto import MaxOptimizer, qint2, qint4
+            from optimum.quanto import MaxOptimizer, qint2, qint4  # type: ignore

         if self.nbits not in [2, 4]:
             raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -867,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
     def _quantize(self, tensor, axis):
         # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
         if is_optimum_quanto_available():
-            from optimum.quanto import quantize_weight
+            from optimum.quanto import quantize_weight  # type: ignore

             scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
             qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1176,7 +1177,7 @@ class StaticCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
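Not part of the diff: what the `torch.device("meta")` default buys. A meta tensor records shape and dtype but allocates no storage, so a cache whose device is not yet known can be constructed essentially for free and given a real device later:

```python
t = torch.zeros(2, 8, 1024, 64, device="meta")
print(t.device, t.shape)                # meta torch.Size([2, 8, 1024, 64])
print(t.element_size() * t.nelement())  # bytes it *would* occupy; nothing is allocated
```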
@@ -1195,8 +1196,6 @@ class StaticCache(Cache):
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1366,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1681,7 +1680,7 @@ class HybridCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: Union[torch.device, str] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1710,7 +1709,6 @@ class HybridCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
         self.is_sliding = torch.tensor(
             [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1843,7 +1841,7 @@ class HybridCache(Cache):
         return self.max_batch_size


-class MambaCache:
+class MambaCache(Cache):
     """
     Cache for mamba model which does not have attention mechanism and key value states.
@@ -1900,7 +1898,7 @@ class MambaCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         dtype: torch.dtype = torch.float16,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         max_batch_size: Optional[int] = None,
     ):
         if batch_size is not None:
@@ -1908,12 +1906,10 @@ class MambaCache(Cache):
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
-        self.dtype = dtype
         self.max_batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
-        self.device = torch.device(device) if device is not None else torch.device("meta")

         self.conv_states: List[torch.Tensor] = []
         self.ssm_states: List[torch.Tensor] = []
@@ -2043,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: Optional[int],
-        device: Union[str, torch.device],
-        dtype: Optional[torch.dtype] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
+        dtype: torch.dtype = torch.float32,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super(Cache, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32

         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads