mirror of https://github.com/saymrwulf/transformers.git

fixed dynamic cache

parent 80b49d721b
commit 5ccb79c16d

1 changed file with 48 additions and 55 deletions
@@ -36,19 +36,28 @@ class Cache(torch.Tensor):
             k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
         }
 
-        for argument in ["dtype", "device", "requires_grad"]:
+        for argument in ["dtype", "device"]:
             if argument in init_arguments:
-                argument_index = init_arguments.index(argument)
-                if len(args) > argument_index:
-                    wrapper_kwargs[argument] = args[argument_index]
-                elif argument in kwargs:
+                arg_idx = init_arguments.index(argument)
+                if len(args) > arg_idx and args[arg_idx] is not None:
+                    wrapper_kwargs[argument] = args[arg_idx]
+                elif kwargs.get(argument, None) is not None:
                     wrapper_kwargs[argument] = kwargs[argument]
-                elif argument in init_defaults:
+                elif init_defaults[argument] is not None:
                     wrapper_kwargs[argument] = init_defaults[argument]
 
-        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs)
-        cls.__init__(self, *args, **kwargs)
+        if "cache_config" in init_arguments:
+            cache_config_idx = init_arguments.index("cache_config")
+            if len(args) > cache_config_idx and args[cache_config_idx] is not None:
+                wrapper_kwargs["device"] = args[cache_config_idx].device
+            elif kwargs.get("cache_config", None) is not None:
+                wrapper_kwargs["device"] = kwargs["cache_config"].device
+            elif init_defaults["cache_config"] is not None:
+                wrapper_kwargs["device"] = init_defaults["cache_config"].device
+
+        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
+        # we create a dummy empty tensor for generic tensor flattening/unflattening
+        self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
         return self
 
     @classmethod
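The rewritten `__new__` funnels `dtype`/`device` (and now `cache_config.device`) into `torch.Tensor._make_wrapper_subclass`, which builds a storage-less tensor shell that merely advertises metadata. A minimal sketch of the pattern, using an illustrative `TaggedTensor` class that is not part of this commit:

import torch


class TaggedTensor(torch.Tensor):
    # storage-less wrapper: reports dtype/device like a tensor but holds no
    # data of its own, only ordinary Python attributes
    @staticmethod
    def __new__(cls, tag, dtype=torch.float32, device="cpu"):
        self = torch.Tensor._make_wrapper_subclass(
            cls, (), dtype=dtype, device=torch.device(device), requires_grad=False
        )
        self.tag = tag
        return self


t = TaggedTensor("kv-cache", dtype=torch.float16)
print(t.dtype, t.device, t.tag)  # torch.float16 cpu kv-cache

Dropping the explicit `cls.__init__(self, *args, **kwargs)` call is safe because Python invokes `__init__` automatically when `__new__` returns an instance of `cls`; the old code therefore initialized the object twice.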
@@ -66,42 +75,27 @@ class Cache(torch.Tensor):
         # I think `if past_key_values is not None:` should be used instead
         return self is not None  # True
 
-    def __setattr__(self, name, value):
-        # for the many places where `self.device` or `self.dtype` is set
-        if name in ["dtype", "device"]:
-            self.to(value)
-        else:
-            return super().__setattr__(name, value)
-
     def to(self, *args, **kwargs):
         # originals
-        wrapper_kwargs = {
-            "dtype": getattr(self, "dtype", None),
-            "device": getattr(self, "device", None),
-            "requires_grad": getattr(self, "requires_grad", False),
-        }
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
 
         # overrides
         for arg in list(args) + list(kwargs.values()):
-            if isinstance(arg, torch.device):
+            if isinstance(arg, (torch.device, str, int)):
                 wrapper_kwargs["device"] = arg
             elif isinstance(arg, torch.dtype):
                 wrapper_kwargs["dtype"] = arg
 
-        new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
-        new_tensor_wrapper.__dict__ = self.__dict__
-        self = new_tensor_wrapper
-        return self
+        # new wrapper
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
+        return new_self
 
     def clone(self):
-        wrapper_kwargs = {
-            "dtype": getattr(self, "dtype", None),
-            "device": getattr(self, "device", None),
-            "requires_grad": getattr(self, "requires_grad", False),
-        }
-        new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
-        new_tensor_wrapper.__dict__ = copy.deepcopy(self.__dict__)
-        return new_tensor_wrapper
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self
 
     def update(
         self,
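The widened `isinstance(arg, (torch.device, str, int))` check matches what `torch.Tensor.to` itself accepts: a device object, a string such as "cuda:0", or a bare ordinal, all of which `torch.device` normalizes. Note also that the rebuilt wrapper copies `__dict__` minus any stale `device`/`dtype` entries, since those now live on the tensor shell itself. The three accepted device spellings, for reference:

import torch

print(torch.device("cuda:0"))             # device(type='cuda', index=0)
print(torch.device(0))                    # device(type='cuda', index=0)
print(torch.device(torch.device("cpu")))  # device(type='cpu')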
@@ -375,7 +369,7 @@ class StaticCacheConfig(CacheConfig):
 
     cache_implementation = "static"
 
-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
         self.device = device
@@ -432,6 +426,16 @@ class DynamicCache(Cache):
    ```
    """
 
+    def __tensor_flatten__(self):
+        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, _, __):
+        cache = DynamicCache()
+        cache._seen_tokens = meta["_seen_tokens"]
+        cache._empty_tensor = inner_tensors["_empty_tensor"]
+        return cache
+
     @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
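`__tensor_flatten__` and `__tensor_unflatten__` are the hooks PyTorch's traceable-subclass protocol (used by torch.compile) calls to decompose a subclass into plain inner tensors plus static metadata and later rebuild it; the dummy `_empty_tensor` created in `__new__` gives the protocol at least one real tensor to carry. A sketch of the round trip, assuming the definitions above are in scope:

cache = DynamicCache()
attrs, meta = cache.__tensor_flatten__()  # (["_empty_tensor"], {"_seen_tokens": ...})
inner = {name: getattr(cache, name) for name in attrs}
rebuilt = DynamicCache.__tensor_unflatten__(inner, meta, None, None)
assert rebuilt._seen_tokens == cache._seen_tokens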
@@ -519,7 +523,7 @@ class DynamicCache(Cache):
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
         )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
         return layer_seq_length
 
     def get_max_cache_shape(self) -> Optional[int]:
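The empty-layer fallback changes from the Python int 0 to `torch.tensor(0)`. A 0-dim tensor still composes with the arithmetic and comparisons callers typically perform on a sequence length, so the swap should be transparent at most call sites; this interop check is ours, not part of the diff:

import torch

seq_len = torch.tensor(0)   # the new empty-layer fallback
print(seq_len + 12)         # tensor(12)
print(bool(seq_len == 0))   # True
print(int(seq_len))         # 0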
@@ -746,9 +750,6 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
-
         super().__init__()
 
     def update(
         self,
@@ -848,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
             raise ImportError(
                 f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
             )
-        from optimum.quanto import MaxOptimizer, qint2, qint4
+        from optimum.quanto import MaxOptimizer, qint2, qint4  # type: ignore
 
         if self.nbits not in [2, 4]:
             raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -867,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
     def _quantize(self, tensor, axis):
         # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
         if is_optimum_quanto_available():
-            from optimum.quanto import quantize_weight
+            from optimum.quanto import quantize_weight  # type: ignore
 
             scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
             qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1176,7 +1177,7 @@ class StaticCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
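`torch.device("meta")` as the new default stands for "not placed yet": meta tensors carry shape and dtype but allocate no storage, so a static cache can be laid out before its real device is known. A quick demonstration of that behavior:

import torch

x = torch.zeros(2, 8, 1024, 64, device="meta")
print(x.shape, x.dtype)  # torch.Size([2, 8, 1024, 64]) torch.float32
print(x.device)          # meta
# materializing data fails: x.to("cpu") raises, since there is nothing to copy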
@@ -1195,8 +1196,6 @@ class StaticCache(Cache):
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )
 
-        self.dtype = dtype
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1366,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1681,7 +1680,7 @@ class HybridCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: Union[torch.device, str] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1710,7 +1709,6 @@ class HybridCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
 
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
         self.is_sliding = torch.tensor(
             [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1843,7 +1841,7 @@
         return self.max_batch_size
 
 
-class MambaCache:
+class MambaCache(Cache):
     """
     Cache for mamba model which does not have attention mechanism and key value states.
 
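Re-parenting `MambaCache` onto `Cache` pulls the one attention-free cache into the same tensor-subclass machinery, so generic `isinstance(..., torch.Tensor)` plumbing now covers every cache type. A sketch of what that buys on this branch (stock transformers releases return False here, since their caches are plain objects):

import torch
from transformers import DynamicCache  # MambaCache needs a config, so the sketch uses DynamicCache

cache = DynamicCache()
print(isinstance(cache, torch.Tensor))  # True on this branch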
@@ -1900,7 +1898,7 @@ class MambaCache:
         config: PretrainedConfig,
         batch_size: int = None,
         dtype: torch.dtype = torch.float16,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         max_batch_size: Optional[int] = None,
     ):
         if batch_size is not None:
@@ -1908,12 +1906,10 @@ class MambaCache:
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
-        self.dtype = dtype
         self.max_batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
-        self.device = torch.device(device) if device is not None else torch.device("meta")
 
         self.conv_states: List[torch.Tensor] = []
         self.ssm_states: List[torch.Tensor] = []
@@ -2043,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: Optional[int],
-        device: Union[str, torch.device],
-        dtype: Optional[torch.dtype] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
+        dtype: torch.dtype = torch.float32,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super(Cache, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32
 
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
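`OffloadedStaticCache` keeps the full cache on `offload_device` and stages per-layer slices onto the compute device as layers run; the signature change above only aligns its `device`/`dtype` defaults with the other caches. The staging idea in miniature, with illustrative buffer shapes rather than the class's actual code:

import torch

use_cuda = torch.cuda.is_available()
# the full buffer lives on the offload device; pinned memory enables async copies
offload = torch.zeros(8, 1024, 64, device="cpu", pin_memory=use_cuda)
staged = offload.to("cuda", non_blocking=True) if use_cuda else offload  # per-layer staging copy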