mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
FEAT : Adding VPTQ quantization method to HFQuantizer (#34770)
* init vptq
* add integration
* add vptq support fix readme
* add tests && format
* format
* address comments
* format
* format
* address comments
* format
* address comments
* remove debug code
* Revert "remove debug code"

  This reverts commit ed3b3eaaba82caf58cb3aa6e865d98e49650cf66.
* fix test

---------

Co-authored-by: Yang Wang <wyatuestc@gmail.com>
This commit is contained in:
parent
5a2aedca1e
commit
4e27a4009d
21 changed files with 647 additions and 3 deletions
@@ -50,6 +50,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/pef

# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2

# Add vptq for quantization testing
RUN python3 -m pip install --no-cache-dir vptq

# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq
@@ -157,6 +157,8 @@
#     title: AWQ
#   - local: quantization/aqlm
#     title: AQLM
#   - local: quantization/vptq
#     title: VPTQ
#   - local: quantization/quanto
#     title: Quanto
#   - local: quantization/eetq
@@ -167,6 +167,8 @@
    title: AWQ
  - local: quantization/aqlm
    title: AQLM
  - local: quantization/vptq
    title: VPTQ
  - local: quantization/quanto
    title: Quanto
  - local: quantization/eetq
@@ -473,7 +473,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable

Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by your GPU's memory. If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.

> [!TIP]
> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.

Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
@@ -34,6 +34,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] AqlmConfig

## VptqConfig

[[autodoc]] VptqConfig

## AwqConfig

[[autodoc]] AwqConfig
@@ -58,6 +58,7 @@ Use the table below to help you decide which quantization method to use.

| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | 🔴 | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |

<Tip>
@@ -71,4 +72,4 @@ We value your feedback to help identify bugs before the full release! Check out

\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships.

</Tip>
111
docs/source/en/quantization/vptq.md
Normal file
@@ -0,0 +1,111 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# VPTQ

> [!TIP]
> Try VPTQ on [Hugging Face](https://huggingface.co/spaces/microsoft/VPTQ)!
> Try VPTQ on [Google Colab](https://colab.research.google.com/github/microsoft/VPTQ/blob/main/notebooks/vptq_example.ipynb)!
> Learn more about VPTQ on [arXiv](https://arxiv.org/pdf/2409.17066)!

Vector Post-Training Quantization ([VPTQ](https://github.com/microsoft/VPTQ)) is a novel post-training quantization method that leverages vector quantization to achieve high accuracy on LLMs at an extremely low bit-width (<2-bit). VPTQ can compress 70B and even 405B models to 1-2 bits without retraining while maintaining high accuracy.

- Better accuracy at 1-2 bits (405B @ <2-bit, 70B @ 2-bit)
- Lightweight quantization algorithm: only ~17 hours to quantize Llama-3.1 405B
- Agile quantization inference: low decode overhead, high throughput, and low time-to-first-token (TTFT)
Inference support for VPTQ is released in the `vptq` library. Make sure to install it to run the models:

```bash
pip install vptq
```

The library provides efficient kernels for NVIDIA/AMD GPU inference.

To run VPTQ models, simply load a model that has been quantized with VPTQ:

## Inference example

**Run Llama 3.1 70B on an RTX 4090 (24GB @ ~2 bits) in real time**

![Real-time inference of Llama 3.1 70B on an RTX 4090](https://github.com/microsoft/VPTQ/blob/main/assets/vptq.gif?raw=true)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

quantized_model = AutoModelForCausalLM.from_pretrained(
    "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft")
input_ids = tokenizer("hello, it's me", return_tensors="pt").to("cuda")
out = quantized_model.generate(**input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```
## Quantize your own model

The VPTQ algorithm is available as an early release at [VPTQ](https://github.com/microsoft/VPTQ/tree/algorithm);
check out the [tutorial](https://github.com/microsoft/VPTQ/blob/algorithm/algorithm.md) for details.
## Early Results from Tech Report

VPTQ achieves better accuracy and higher throughput with lower quantization overhead across models of different sizes. The following experimental results are for reference only; VPTQ can achieve better outcomes under reasonable parameters, especially in terms of model accuracy and inference speed.

| Model       | bitwidth | W2↓  | C4↓  | AvgQA↑ | tok/s↑ | mem(GB) | cost/h↓ |
| ----------- | -------- | ---- | ---- | ------ | ------ | ------- | ------- |
| LLaMA-2 7B  | 2.02     | 6.13 | 8.07 | 58.2   | 39.9   | 2.28    | 2       |
|             | 2.26     | 5.95 | 7.87 | 59.4   | 35.7   | 2.48    | 3.1     |
| LLaMA-2 13B | 2.02     | 5.32 | 7.15 | 62.4   | 26.9   | 4.03    | 3.2     |
|             | 2.18     | 5.28 | 7.04 | 63.1   | 18.5   | 4.31    | 3.6     |
| LLaMA-2 70B | 2.07     | 3.93 | 5.72 | 68.6   | 9.7    | 19.54   | 19      |
|             | 2.11     | 3.92 | 5.71 | 68.7   | 9.7    | 20.01   | 19      |
## More Models in [VPTQ-community](https://huggingface.co/VPTQ-community)

⚠️ The VPTQ repository only provides the model quantization algorithm.

⚠️ The open-source VPTQ-community provides models based on the technical report and the quantization algorithm.

**Quick Estimation of Model Bitwidth (Excluding Codebook Overhead)**:
- **Model Naming Convention**: The model's name encodes the **vector length** $v$, **codebook (lookup table) size**, and **residual codebook size**. For example, "Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft" is "Meta-Llama-3.1-70B-Instruct" quantized with:
  - **Vector Length**: 8
  - **Number of Centroids**: 65536 (2^16)
  - **Number of Residual Centroids**: 256 (2^8)
- **Equivalent Bitwidth Calculation**:
  - **Index**: log2(65536) = 16 bits / vector length 8 = 2 bits per weight
  - **Residual Index**: log2(256) = 8 bits / vector length 8 = 1 bit per weight
  - **Total Bitwidth**: 2 + 1 = 3 bits per weight
- **Model Size Estimation**: 70B * 3 bits / 8 bits per byte = 26.25 GB

- **Note**: This estimate does not include the size of the codebook (lookup table), other parameter overheads, or the padding overhead for storing indices. For the detailed calculation method, please refer to **Tech Report Appendix C.2**.
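The back-of-the-envelope calculation above can be sketched in a few lines. This is a rough estimate based only on the naming convention; codebook overhead is still excluded, and the helper name is illustrative:

```python
import math

def estimate_bits_per_weight(vector_len, num_centroids, num_res_centroids):
    # Index bits and residual index bits, each amortized over the vector length.
    index_bits = math.log2(num_centroids) / vector_len
    res_bits = math.log2(num_res_centroids) / vector_len if num_res_centroids > 1 else 0.0
    return index_bits + res_bits

# "v8-k65536-256": vector length 8, 65536 centroids, 256 residual centroids
bits = estimate_bits_per_weight(8, 65536, 256)
print(bits)                   # 3.0
print(70e9 * bits / 8 / 1e9)  # 26.25 (GB, excluding codebook overhead)
```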

| Model Series | Collections | (Estimated) Bit per weight |
| :---: | :---: | :--- |
| Llama 3.1 Nemotron 70B Instruct HF | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-nemotron-70b-instruct-hf-without-finetune-671730b96f16208d0b3fe942) | [4 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-16384-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-256-woft) |
| Llama 3.1 8B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-8b-instruct-without-finetune-66f2b70b1d002ceedef02d2e) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-65536-woft) [3.5 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-4096-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft) [2.3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft) |
| Llama 3.1 70B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-70b-instruct-without-finetune-66f2bf454d3dd78dfee2ff11) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft) [2.25 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft) [1.93 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-32768-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k32768-0-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k16384-0-woft) |
| Llama 3.1 405B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-405b-instruct-without-finetune-66f4413f9ba55e1a9e52cfb0) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-256-woft) [2 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-65536-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k32768-32768-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-1024-woft) [1.5 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k4096-0-woft) [1.5 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-256-woft) [1.43 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-128-woft) [1.375 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-64-woft) |
| Mistral Large Instruct 2407 (123B) | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-mistral-large-instruct-2407-without-finetune-6711ebfb7faf85eed9cceb16) | [4 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-16384-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-4096-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-256-woft) |
| Qwen 2.5 7B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-7b-instruct-without-finetune-66f3e9866d3167cc05ce954a) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v16-k65536-65536-woft) |
| Qwen 2.5 14B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-14b-instruct-without-finetune-66f827f83c7ffa7931b8376c) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v16-k65536-65536-woft) |
| Qwen 2.5 32B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-32b-instruct-without-finetune-66fe77173bf7d64139f0f613) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k256-256-woft) |
| Qwen 2.5 72B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-72b-instruct-without-finetune-66f3bf1b3757dfa1ecb481c0) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-256-woft) [2.38 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k1024-512-woft) [2.25 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k512-512-woft) [2.25 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-0-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-65536-woft) [1.94 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-32768-woft) |
| Reproduced from the tech report | [HF 🤗](https://huggingface.co/collections/VPTQ-community/reproduced-vptq-tech-report-baseline-66fbf1dffe741cc9e93ecf04) | Results from the open source community for reference only, please use them responsibly. |
| Hessian and Inverse Hessian Matrix | [HF 🤗](https://huggingface.co/collections/VPTQ-community/hessian-and-invhessian-checkpoints-66fd249a104850d17b23fd8b) | Collected from RedPajama-Data-1T-Sample, following [Quip#](https://github.com/Cornell-RelaxML/quip-sharp/blob/main/quantize_llama/hessian_offline_llama.py) |
@@ -151,6 +151,8 @@
    title: AWQ
  - local: in_translation
    title: (in translation) AQLM
  - local: in_translation
    title: (in translation) VPTQ
  - local: in_translation
    title: (in translation) Quanto
  - local: in_translation
@@ -173,6 +175,8 @@
    title: (in translation) AWQ
  - local: in_translation
    title: (in translation) AQLM
  - local: in_translation
    title: (in translation) VPTQ
  - local: quantization/quanto
    title: Quanto
  - local: quantization/eetq
@@ -375,7 +375,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable

Quantization reduces the size of LLM weights by storing them in a lower precision. This lowers memory usage and makes loading an LLM for inference more accessible when GPU memory is constrained. If you have enough GPU memory, you don't need to quantize the model, since the extra quantization and dequantization steps can introduce a small latency cost (except for AWQ and fused AWQ modules).

> [!TIP]
> There are many quantization libraries (see the [Quantization](./quantization) guide for details), including Quanto, AQLM, AWQ, and AutoGPTQ. Try the one that best fits your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post, which compares AutoGPTQ and bitsandbytes.
> There are many quantization libraries (see the [Quantization](./quantization) guide for details), including Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Try the one that best fits your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post, which compares AutoGPTQ and bitsandbytes.

Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it takes to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
@@ -35,6 +35,10 @@ Quantization methods not supported in Transformers [`HfQuantizer`]

[[autodoc]] AqlmConfig

## VptqConfig[[transformers.VptqConfig]]

[[autodoc]] VptqConfig

## AwqConfig[[transformers.AwqConfig]]

[[autodoc]] AwqConfig
@@ -1000,6 +1000,7 @@ _import_structure = {
        "HqqConfig",
        "QuantoConfig",
        "TorchAoConfig",
        "VptqConfig",
    ],
}
@@ -6017,6 +6018,7 @@ if TYPE_CHECKING:
        HqqConfig,
        QuantoConfig,
        TorchAoConfig,
        VptqConfig,
    )

    try:
@@ -105,6 +105,7 @@ _import_structure = {
    ],
    "peft": ["PeftAdapterMixin"],
    "quanto": ["replace_with_quanto_layers"],
    "vptq": ["replace_with_vptq_linear"],
}

try:
@@ -207,6 +208,7 @@ if TYPE_CHECKING:
    )
    from .peft import PeftAdapterMixin
    from .quanto import replace_with_quanto_layers
    from .vptq import replace_with_vptq_linear

    try:
        if not is_torch_available():
101
src/transformers/integrations/vptq.py
Normal file
@@ -0,0 +1,101 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"VPTQ (Vector Post-Training Quantization) integration file"

import torch.nn as nn
from accelerate import init_empty_weights
from vptq import VQuantLinear


def replace_with_vptq_linear(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with VPTQ quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`VptqConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert to `VQuantLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """

    modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        layer_name = ".".join(current_key_name)
        shared_layer_config = quantization_config.shared_layer_config
        config_for_layers = quantization_config.config_for_layers

        if (
            isinstance(module, nn.Linear)
            and layer_name not in modules_to_not_convert
            and ((layer_name in config_for_layers) or (current_key_name[-1] in shared_layer_config))
        ):
            layer_params = config_for_layers.get(layer_name, None) or shared_layer_config.get(
                current_key_name[-1], None
            )

            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                model._modules[name] = VQuantLinear(
                    in_features,
                    out_features,
                    vector_lens=layer_params["vector_lens"],
                    num_centroids=layer_params["num_centroids"],
                    num_res_centroids=layer_params["num_res_centroids"],
                    group_num=layer_params["group_num"],
                    group_size=layer_params["group_size"],
                    outlier_size=layer_params["outlier_size"],
                    indices_as_float=layer_params["indices_as_float"],
                    enable_norm=layer_params["enable_norm"],
                    enable_perm=layer_params["enable_perm"],
                    is_indice_packed=True,
                    enable_proxy_error=False,
                    bias=module.bias is not None,
                )
                has_been_replaced = True

                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_vptq_linear(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
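The per-layer config resolution in `replace_with_vptq_linear` (an exact layer-name entry wins; otherwise the last key component falls back to the shared config) can be exercised in isolation with plain dictionaries. The layer names and values below are made up for illustration:

```python
# Mimics the lookup in replace_with_vptq_linear: an exact layer-name entry in
# config_for_layers takes precedence; otherwise the last component of the
# layer name (e.g. "q_proj") is looked up in shared_layer_config.
config_for_layers = {"model.layers.0.self_attn.q_proj": {"num_centroids": 65536}}
shared_layer_config = {"q_proj": {"num_centroids": 256}}

def resolve_layer_params(layer_name):
    last_key = layer_name.rsplit(".", 1)[-1]
    return config_for_layers.get(layer_name, None) or shared_layer_config.get(last_key, None)

print(resolve_layer_params("model.layers.0.self_attn.q_proj"))  # per-layer entry wins
print(resolve_layer_params("model.layers.1.self_attn.q_proj"))  # falls back to shared config
```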
@@ -29,6 +29,7 @@ from ..utils.quantization_config import (
    QuantizationMethod,
    QuantoConfig,
    TorchAoConfig,
    VptqConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -42,6 +43,7 @@ from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -57,6 +59,7 @@ AUTO_QUANTIZER_MAPPING = {
    "fbgemm_fp8": FbgemmFp8HfQuantizer,
    "torchao": TorchAoHfQuantizer,
    "bitnet": BitNetHfQuantizer,
    "vptq": VptqHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -72,6 +75,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "fbgemm_fp8": FbgemmFp8Config,
    "torchao": TorchAoConfig,
    "bitnet": BitNetConfig,
    "vptq": VptqConfig,
}
98
src/transformers/quantizers/quantizer_vptq.py
Normal file
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_torch_available, is_vptq_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class VptqHfQuantizer(HfQuantizer):
    """
    Quantizer of the VPTQ method. Enables the loading of prequantized models.
    """

    requires_calibration = True
    required_packages = ["vptq"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available():
            raise ImportError("Using `vptq` quantization requires Accelerate: `pip install accelerate`")

        if not is_vptq_available():
            raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            else:
                import vptq

                device_availability = getattr(vptq, "device_availability", lambda device: False)
                if device_availability("cpu") is True:
                    raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
                torch_dtype = torch.float32
                logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        """
        We don't have a parameter like `modules_to_not_convert` to indicate which layers should not be quantized,
        because `quantization_config` includes the layers that should be quantized.
        """
        from ..integrations import replace_with_vptq_linear

        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
            self.quantization_config.modules_to_not_convert or []
        )

        replace_with_vptq_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True
@@ -142,6 +142,7 @@ from .utils import (
    is_torchdynamo_available,
    is_torchvision_available,
    is_vision_available,
    is_vptq_available,
    strtobool,
)
@@ -1142,6 +1143,13 @@ def require_aqlm(test_case):
    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_vptq(test_case):
    """
    Decorator marking a test that requires vptq
    """
    return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
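The `require_vptq` helper follows the standard `unittest.skipUnless` decorator pattern. A self-contained sketch with a dummy availability flag (all names here are illustrative, standing in for `is_vptq_available`):

```python
import unittest

_fake_backend_available = False  # stand-in for is_vptq_available()

def require_fake_backend(test_case):
    """Decorator marking a test that requires the (fake) backend."""
    return unittest.skipUnless(_fake_backend_available, "test requires fake backend")(test_case)

@require_fake_backend
class FakeBackendTest(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)

# With the flag False, the whole test case is recorded as skipped, not failed.
result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(FakeBackendTest)
)
print(len(result.skipped))  # 1
```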
|
||||
def require_eetq(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires eetq
|
||||
|
|
|
|||
|
|
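A self-contained sketch of how such a `require_*` decorator behaves; `is_vptq_available` here is a local stand-in that simulates the package being absent, not the real transformers utility:

```python
import unittest

# Hypothetical stand-in for transformers.utils.is_vptq_available; simulates vptq being absent.
def is_vptq_available():
    return False

def require_vptq(test_case):
    """Decorator marking a test that requires vptq (same shape as the hunk above)."""
    return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)

@require_vptq
class DummyVptqTest(unittest.TestCase):
    def test_noop(self):
        pass

# With vptq unavailable, unittest marks the decorated class as skipped.
print(DummyVptqTest.__unittest_skip__)  # True
```

Applying the decorator to a class skips every test in it, which is exactly how `VptqTest` below is gated.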
@@ -233,6 +233,7 @@ from .import_utils import (
    is_training_run_on_sagemaker,
    is_uroman_available,
    is_vision_available,
    is_vptq_available,
    requires_backends,
    torch_only_method,
)
@@ -93,11 +93,13 @@ FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")

@@ -816,6 +818,10 @@ def is_aqlm_available():
    return _aqlm_available


def is_vptq_available(min_version: str = VPTQ_MIN_VERSION):
    return _vptq_available and version.parse(_vptq_version) >= version.parse(min_version)


def is_av_available():
    return _av_available
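The version gate in `is_vptq_available` can be sketched without the `packaging` dependency; `_parse` below is a minimal numeric substitute for `packaging.version.parse`, and `is_vptq_available_sketch` is a hypothetical standalone version of the check:

```python
VPTQ_MIN_VERSION = "0.0.4"

def _parse(v):
    # Minimal numeric version parse; the real code uses packaging.version.parse.
    return tuple(int(part) for part in v.split("."))

def is_vptq_available_sketch(installed_version, min_version=VPTQ_MIN_VERSION):
    # Package must be importable (version string present) AND at least the minimum version.
    return installed_version is not None and _parse(installed_version) >= _parse(min_version)

print(is_vptq_available_sketch("0.0.4"))  # True
print(is_vptq_available_sketch("0.0.3"))  # False
print(is_vptq_available_sketch(None))     # False
```

Tuple comparison gives the same ordering as `packaging` for plain numeric versions, which is all the minimum-version gate needs here.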
@@ -39,6 +39,7 @@ class QuantizationMethod(str, Enum):
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    VPTQ = "vptq"
    QUANTO = "quanto"
    EETQ = "eetq"
    HQQ = "hqq"
@@ -994,6 +995,102 @@ class AqlmConfig(QuantizationConfigMixin):
            self.linear_weights_not_to_quantize = []


@dataclass
class VptqLayerConfig(QuantizationConfigMixin):
    """
    This is used to explain the vptq config params for each layer.

    Args:
        enable_norm (`bool`, *optional*, defaults to `True`): whether to apply a scale/bias to the fp weight
        enable_perm (`bool`, *optional*, defaults to `True`): whether to permute the input channels
        group_num (`int`, *optional*, defaults to `1`): number of groups for vector quantization
        group_size (`int`, *optional*, defaults to `-1`): depends on out-features
        in_features (`int`, *optional*, defaults to `-1`): input dimension of the layer
        indices_as_float (`bool`, *optional*, defaults to `False`): whether to store indices as float, for fine-tuning
        is_indice_packed (`bool`, *optional*, defaults to `True`): should always be True
        num_centroids (`list`, *optional*, defaults to `[-1, -1]`): number of centroids of the clusters
        num_res_centroids (`list`, *optional*, defaults to `[-1, -1]`): number of residual centroids
        out_features (`int`, *optional*, defaults to `-1`): output dimension of the layer
        outlier_size (`int`, *optional*, defaults to `0`): number of outliers
        vector_lens (`list`, *optional*, defaults to `[-1, -1]`): centroid vector length used in quantization
    """

    def __init__(
        self,
        enable_norm: bool = True,
        enable_perm: bool = True,
        group_num: int = 1,
        group_size: int = -1,
        in_features: int = -1,
        indices_as_float: bool = False,
        is_indice_packed: bool = True,
        num_centroids: tuple = [-1, -1],
        num_res_centroids: tuple = [-1, -1],
        out_features: int = -1,
        outlier_size: int = 0,
        vector_lens: tuple = [-1, -1],
        **kwargs,
    ):
        self.enable_norm = enable_norm
        self.enable_perm = enable_perm
        self.group_num = group_num
        self.group_size = group_size
        self.in_features = in_features
        self.indices_as_float = indices_as_float
        self.is_indice_packed = is_indice_packed
        self.num_centroids = num_centroids
        self.num_res_centroids = num_res_centroids
        self.out_features = out_features
        self.outlier_size = outlier_size
        self.vector_lens = vector_lens
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        if self.is_indice_packed is False:
            raise ValueError("is_indice_packed should always be True")


@dataclass
class VptqConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about `vptq` parameters.

    Args:
        enable_proxy_error (`bool`, *optional*, defaults to `False`): calculate proxy error for each layer
        config_for_layers (`Dict`, *optional*, defaults to `{}`): quantization params for each layer
        shared_layer_config (`Dict`, *optional*, defaults to `{}`): shared quantization params among layers
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
        kwargs (`Dict[str, Any]`, *optional*):
            Additional parameters from which to initialize the configuration object.
    """

    def __init__(
        self,
        enable_proxy_error: bool = False,
        config_for_layers: Dict[str, Any] = {},
        shared_layer_config: Dict[str, Any] = {},
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.VPTQ
        self.enable_proxy_error = enable_proxy_error
        self.config_for_layers: Dict[str, Any] = config_for_layers
        self.shared_layer_config: Dict[str, Any] = shared_layer_config
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        for layer_name, layer_param in self.config_for_layers.items():
            VptqLayerConfig(**layer_param)
        if self.enable_proxy_error is True:
            raise ValueError("enable_proxy_error should always be False until we support training")


@dataclass
class QuantoConfig(QuantizationConfigMixin):
    """
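The two `post_init` checks can be condensed into a single standalone sketch; `validate_vptq_config` is a hypothetical helper (not part of transformers) that reproduces the validation rules on plain dicts:

```python
# Standalone sketch of the VptqConfig/VptqLayerConfig validation rules.
# validate_vptq_config is a hypothetical helper, not the transformers API.
def validate_vptq_config(enable_proxy_error=False, config_for_layers=None):
    for layer_name, layer_param in (config_for_layers or {}).items():
        # VptqLayerConfig(**layer_param) enforces is_indice_packed=True per layer
        if layer_param.get("is_indice_packed", True) is False:
            raise ValueError(f"{layer_name}: is_indice_packed should always be True")
    # training-time proxy error is rejected until training is supported
    if enable_proxy_error:
        raise ValueError("enable_proxy_error should always be False until we support training")

# A config with default-packed indices and proxy error disabled passes silently.
validate_vptq_config(config_for_layers={"model.layers.0.self_attn.q_proj": {"group_num": 1}})
```

Both checks fail fast at construction time, so a malformed checkpoint config raises before any weights are loaded.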
0    tests/quantization/vptq_integration/__init__.py    Normal file
194  tests/quantization/vptq_integration/test_vptq.py   Normal file
@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, VptqConfig
from transformers.testing_utils import (
    require_accelerate,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_vptq,
    slow,
    torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


class VptqConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Makes sure the config format is properly set
        """
        quantization_config = VptqConfig()
        vptq_orig_config = quantization_config.to_dict()

        self.assertEqual(quantization_config.quant_config, vptq_orig_config["quant_config"])


@slow
@require_torch_gpu
@require_vptq
@require_accelerate
class VptqTest(unittest.TestCase):
    model_name = "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft"

    input_text = "Hello my name is"
    max_new_tokens = 32

    EXPECTED_OUTPUT = "Hello my name is Sarah and I am a 25 year old woman from the United States. I am a college graduate and I am currently working as a marketing specialist for a small"

    device_map = "cuda"

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            device_map=cls.device_map,
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_raise_if_non_quantized(self):
        model_id = "facebook/opt-125m"
        quantization_config = VptqConfig()

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from vptq import VQuantLinear

        from transformers.integrations import replace_with_vptq_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        modules_to_not_convert = ["lm_head"]
        names = [
            "q_proj",
            "k_proj",
            "v_proj",
            "out_proj",
            "fc1",
            "fc2",
        ]
        value = {
            "enable_norm": True,
            "enable_perm": True,
            "group_num": 1,
            "group_size": 128,
            "indices_as_float": False,
            "num_centroids": [-1, 128],
            "num_res_centroids": [-1, 128],
            "outlier_size": 0,
            "vector_lens": [-1, 12],
        }
        shared_layer_config = {}
        for name in names:
            shared_layer_config[name] = value
        for i in range(24):
            modules_to_not_convert.append("model.decoder.layers.{layer_idx}.fc1".format(layer_idx=i))
        layer_configs = {}
        layer_configs["model.decoder.project_out"] = value
        layer_configs["model.decoder.project_in"] = value
        quantization_config = VptqConfig(config_for_layers=layer_configs, shared_layer_config=shared_layer_config)

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        nb_linears = 0
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                nb_linears += 1

        model, _ = replace_with_vptq_linear(model, quantization_config=quantization_config)
        nb_vptq_linear = 0
        for module in model.modules():
            if isinstance(module, VQuantLinear):
                nb_vptq_linear += 1

        self.assertEqual(nb_linears - 1, nb_vptq_linear)

        # Try with `linear_weights_not_to_quantize`
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)
        quantization_config = VptqConfig(config_for_layers=layer_configs, shared_layer_config=shared_layer_config)
        model, _ = replace_with_vptq_linear(
            model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
        )
        nb_vptq_linear = 0
        for module in model.modules():
            if isinstance(module, VQuantLinear):
                nb_vptq_linear += 1
        # 25 comes from the 24 decoder.layers.{layer_idx}.fc1 modules
        # plus the lm_head
        self.assertEqual(nb_linears - 25, nb_vptq_linear)
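The config construction in `test_quantized_model_conversion` can be replayed with plain dicts, which makes the final `nb_linears - 25` assertion easy to see; this is a standalone sketch, not test code:

```python
# Pure-dict sketch of the config built in test_quantized_model_conversion.
names = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]
value = {
    "enable_norm": True,
    "enable_perm": True,
    "group_num": 1,
    "group_size": 128,
    "indices_as_float": False,
    "num_centroids": [-1, 128],
    "num_res_centroids": [-1, 128],
    "outlier_size": 0,
    "vector_lens": [-1, 12],
}
# The same per-layer params are shared across all six projection/MLP names.
shared_layer_config = {name: value for name in names}

# "25" in the final assertion = 24 per-layer fc1 exclusions + lm_head.
modules_to_not_convert = ["lm_head"]
modules_to_not_convert += [f"model.decoder.layers.{i}.fc1" for i in range(24)]
print(len(modules_to_not_convert))  # 25
```

Every excluded module stays a `torch.nn.Linear`, so the count of converted `VQuantLinear` modules drops by exactly the length of this list.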