mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix cache_utils for optimum.quanto kvcache quantization (#34750)
* add co-author Co-authored-by: w3rew <w3rew@users.noreply.github.com> * fix docs * fix cache * remove print --------- Co-authored-by: w3rew <w3rew@users.noreply.github.com>
This commit is contained in:
parent
4bff54f921
commit
ce1d328e3b
3 changed files with 14 additions and 8 deletions
|
|
@ -55,7 +55,7 @@ Use the table below to help you decide which quantization method to use.
|
|||
| GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp |
|
||||
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
|
||||
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
||||
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
|
||||
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | 🔴 | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
|
||||
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ rendered properly in your Markdown viewer.
|
|||
|
||||
-->
|
||||
|
||||
# Quanto
|
||||
# Optimum-quanto
|
||||
|
||||
<Tip>
|
||||
|
||||
Try Quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!
|
||||
Try optimum-quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
[🤗 Quanto](https://github.com/huggingface/quanto) library is a versatile pytorch quantization toolkit. The quantization method used is the linear quantization. Quanto provides several unique features such as:
|
||||
[🤗 optimum-quanto](https://github.com/huggingface/optimum-quanto) library is a versatile pytorch quantization toolkit. The quantization method used is the linear quantization. Quanto provides several unique features such as:
|
||||
|
||||
- weights quantization (`float8`,`int8`,`int4`,`int2`)
|
||||
- activation quantization (`float8`,`int8`)
|
||||
|
|
@ -37,12 +37,12 @@ Try Quanto + transformers with this [notebook](https://colab.research.google.com
|
|||
Before you begin, make sure the following libraries are installed:
|
||||
|
||||
```bash
|
||||
pip install quanto accelerate transformers
|
||||
pip install optimum-quanto accelerate transformers
|
||||
```
|
||||
|
||||
Now you can quantize a model by passing [`QuantoConfig`] object in the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.
|
||||
|
||||
The integration with transformers only supports weights quantization. For the more complex use case such as activation quantization, calibration and quantization aware training, you should use [quanto](https://github.com/huggingface/quanto) library instead.
|
||||
The integration with transformers only supports weights quantization. For the more complex use case such as activation quantization, calibration and quantization aware training, you should use [optimum-quanto](https://github.com/huggingface/optimum-quanto) library instead.
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
|
||||
|
|
@ -55,7 +55,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cud
|
|||
|
||||
Note that serialization is not supported yet with transformers but it is coming soon! If you want to save the model, you can use quanto library instead.
|
||||
|
||||
Quanto library uses linear quantization algorithm for quantization. Even though this is a basic quantization technique, we get very good results! Have a look at the following benchmark (llama-2-7b on perplexity metric). You can find more benchmarks [here](https://github.com/huggingface/quanto/tree/main/bench/generation)
|
||||
Optimum-quanto library uses linear quantization algorithm for quantization. Even though this is a basic quantization technique, we get very good results! Have a look at the following benchmark (llama-2-7b on perplexity metric). You can find more benchmarks [here](https://github.com/huggingface/optimum-quanto/tree/main/bench/generation)
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -784,6 +784,11 @@ class QuantoQuantizedCache(QuantizedCache):
|
|||
super().__init__(cache_config)
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
|
||||
if optimum_quanto_version <= version.parse("0.2.5"):
|
||||
raise ImportError(
|
||||
f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
|
||||
)
|
||||
from optimum.quanto import MaxOptimizer, qint2, qint4
|
||||
elif is_quanto_available():
|
||||
logger.warning_once(
|
||||
|
|
@ -816,7 +821,8 @@ class QuantoQuantizedCache(QuantizedCache):
|
|||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import quantize_weight
|
||||
|
||||
qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
|
||||
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
|
||||
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
|
||||
return qtensor
|
||||
elif is_quanto_available():
|
||||
logger.warning_once(
|
||||
|
|
|
|||
Loading…
Reference in a new issue