HIGGS Quantization Support (#34997)

* higgs init

* working with crunches

* per-model workspaces

* style

* style 2

* tests and style

* higgs tests passing

* protecting torch import

* removed torch.Tensor type annotations

* torch.nn.Module inheritance fix maybe

* hide inputs inside quantizer calls

* style structure something

* Update src/transformers/quantizers/quantizer_higgs.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* reworked num_sms

* Update src/transformers/integrations/higgs.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* revamped device checks

* docstring upd

* Update src/transformers/quantizers/quantizer_higgs.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* edited tests and device map assertions

* minor edits

* updated flute cuda version in docker

* Added p=1 and 2,3bit HIGGS

* flute version check update

* incorporated `modules_to_not_convert`

* less hardcoding

* Fixed comment

* Added docs

* Fixed gemma support

* example in docs

* fixed torch_dtype for HIGGS

* Update docs/source/en/quantization/higgs.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Collection link

* dequantize interface

* newer flute version, torch.compile support

* unittest message fix

* docs update compile

* isort

* ValueError instead of assert

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Andrei Panferov 2024-12-23 22:54:49 +07:00 committed by GitHub
parent ef1f54a0a7
commit 64c05eecd6
16 changed files with 1249 additions and 0 deletions


@@ -69,6 +69,10 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git
# Add flute-kernel and fast_hadamard_transform for quantization testing
RUN python3 -m pip install --no-cache-dir flute-kernel==0.3.0 -i https://flute-ai.github.io/whl/cu118
RUN python3 -m pip install --no-cache-dir fast_hadamard_transform==1.0.4.post1
# When installing in editable mode, `transformers` is not recognized as a package.
# This line must be added in order for Python to be aware of transformers.
RUN cd transformers && python3 setup.py develop


@@ -173,6 +173,8 @@
title: Quanto
- local: quantization/eetq
title: EETQ
- local: quantization/higgs
title: HIGGS
- local: quantization/hqq
title: HQQ
- local: quantization/fbgemm_fp8


@@ -57,6 +57,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] quantizers.base.HfQuantizer
## HiggsConfig
[[autodoc]] HiggsConfig
## HqqConfig
[[autodoc]] HqqConfig


@@ -0,0 +1,66 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# HIGGS
HIGGS is a zero-shot quantization algorithm that combines Hadamard preprocessing with MSE-optimal quantization grids to achieve lower quantization error and state-of-the-art performance. You can find more information in the paper [arxiv.org/abs/2411.17525](https://arxiv.org/abs/2411.17525).
Runtime support for HIGGS is implemented through the [FLUTE](https://arxiv.org/abs/2407.10960) kernel [library](https://github.com/HanGuo97/flute).
## Quantization Example
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
tokenizer.decode(model.generate(
    **tokenizer("Hi,", return_tensors="pt").to(model.device),
    temperature=0.5,
    top_p=0.80,
)[0])
```
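Note that `device_map` must place the whole model on CUDA devices: the HIGGS quantizer rejects device maps that put any module on CPU or disk.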
## Pre-quantized models
Some pre-quantized models can be found in the [official collection](https://huggingface.co/collections/ISTA-DASLab/higgs-675308e432fd56b7f6dab94e) on Hugging Face Hub.
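Loading one works like loading any other checkpoint; no `quantization_config` is needed because it is stored with the model. A minimal sketch (the repository name below is illustrative; browse the collection for actual checkpoints):

```python
from transformers import AutoModelForCausalLM

# Hypothetical repo id -- pick a real one from the collection above.
model = AutoModelForCausalLM.from_pretrained(
    "ISTA-DASLab/Meta-Llama-3.1-8B-HIGGS",
    device_map="auto",
)
```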
## Current Limitations
**Architectures**
Currently, FLUTE, and HIGGS by extension, **only support Llama 3.0 and 3.1 models with 8B, 70B, and 405B parameters, as well as Gemma-2 9B and 27B**. We're working on supporting a wider range of models, as well as arbitrary models, by modifying the FLUTE compilation procedure.
**torch.compile**
HIGGS is fully compatible with `torch.compile`. With `model.forward` compiled as described [here](../perf_torch_compile.md), these are the speedups it provides on an RTX 4090 for `Llama-3.1-8B-Instruct` (forward passes/sec); a setup sketch follows the table:
| Batch Size | BF16 (With `torch.compile`) | HIGGS 4bit (No `torch.compile`) | HIGGS 4bit (With `torch.compile`) |
|------------|-----------------------------|----------------------------------|-----------------------------------|
| 1 | 59 | 41 | 124 |
| 4 | 57 | 42 | 123 |
| 16 | 56 | 41 | 120 |
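A minimal sketch of the compiled setup measured above, assuming the default compilation mode suffices:

```python
import torch
from transformers import AutoModelForCausalLM, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)
# Compile only the forward pass; the first few calls warm up the compiler.
model.forward = torch.compile(model.forward)
```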
**Quantized training**
Currently, HIGGS doesn't support quantized training (and backward passes in general). We're working on adding support for it.
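## Dequantization
A HIGGS model can be converted back to dense layers with `model.dequantize()`, which replaces every `HiggsLinear` module with a plain `torch.nn.Linear` (see `dequantize_higgs` below). A minimal sketch, assuming enough memory is available for the dense weights:

```python
from transformers import AutoModelForCausalLM, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)
model.dequantize()  # HiggsLinear modules become torch.nn.Linear
```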


@@ -54,6 +54,7 @@ Use the table below to help you decide which quantization method to use.
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2 - 4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |


@@ -998,6 +998,7 @@ _import_structure = {
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig",
"HiggsConfig",
"HqqConfig",
"QuantoConfig",
"TorchAoConfig",
@@ -6023,6 +6024,7 @@ if TYPE_CHECKING:
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
QuantoConfig,
TorchAoConfig,


@@ -63,6 +63,7 @@ _import_structure = {
"load_dequant_gguf_tensor",
"load_gguf",
],
"higgs": ["HiggsLinear", "dequantize_higgs", "quantize_with_higgs", "replace_with_higgs_linear"],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
@@ -166,6 +167,7 @@ if TYPE_CHECKING:
load_dequant_gguf_tensor,
load_gguf,
)
from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,


@@ -0,0 +1,657 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file"
from math import sqrt
from ..utils import (
is_flute_available,
is_hadamard_available,
is_torch_available,
)
if is_torch_available():
import torch
from torch import nn
if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed

if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform
def pad_to_block(tensor, dims, had_block_size, value=0):
"""Pad `tensor` along each dim in `dims` up to the next multiple of `had_block_size`."""
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
for dim in dims:
size = tensor.shape[dim]
next_multiple = ((size - 1) // had_block_size + 1) * had_block_size
pad_dims[-2 * dim - 1] = next_multiple - size
return nn.functional.pad(tensor, pad_dims, "constant", value)
def get_higgs_grid(p: int, n: int):
if (p, n) == (2, 256):
return torch.tensor(
[
[-2.501467704772949, 0.17954708635807037],
[-0.6761789321899414, 1.2728623151779175],
[-1.8025816679000854, 0.7613157629966736],
[-0.538287878036499, -2.6028504371643066],
[0.8415029644966125, -0.8600977659225464],
[0.7023013234138489, 3.3138747215270996],
[0.5699077844619751, 2.5782253742218018],
[3.292393207550049, -0.6016128063201904],
[0.5561617016792297, -1.7723814249038696],
[-2.1012380123138428, 0.020958125591278076],
[0.46085724234580994, 0.8428705334663391],
[1.4548040628433228, -0.6156039237976074],
[3.210029363632202, 0.3546904921531677],
[0.8893890976905823, -0.5967988967895508],
[0.8618854284286499, -3.2061192989349365],
[1.1360996961593628, -0.23852407932281494],
[1.6646337509155273, -0.9265465140342712],
[1.4767773151397705, 1.2476022243499756],
[-1.0511897802352905, 1.94503915309906],
[-1.56318998336792, -0.3264186680316925],
[-0.1829211413860321, 0.2922491431236267],
[-0.8950616717338562, -1.3887052536010742],
[-0.08206957578659058, -1.329533576965332],
[-0.487422913312912, 1.4817842245101929],
[-1.6769757270812988, -2.8269758224487305],
[-1.5057679414749146, 1.8905963897705078],
[1.8335362672805786, 1.0515104532241821],
[0.3273945450782776, 1.0491033792495728],
[-3.295924186706543, -0.7021600008010864],
[-1.8428784608840942, -1.2315762042999268],
[-0.8575026392936707, -1.7005949020385742],
[-1.120667815208435, 0.6467998027801514],
[-0.1588846743106842, -1.804071068763733],
[-0.8539647459983826, 0.5645008683204651],
[-1.4192019701004028, -0.6175029873847961],
[1.0799058675765991, 1.7871345281600952],
[1.171311855316162, 0.7511613965034485],
[2.162078380584717, 0.8044339418411255],
[1.3969420194625854, -1.243762493133545],
[-0.23818807303905487, 0.053944624960422516],
[2.304199457168579, -1.2667627334594727],
[1.4225027561187744, 0.568610668182373],
[0.376836895942688, -0.7134661674499512],
[2.0404467582702637, 0.4087389409542084],
[0.7639489769935608, -1.1367933750152588],
[0.3622530400753021, -1.4827953577041626],
[0.4100743532180786, 0.36108437180519104],
[-1.5867475271224976, -1.618212342262268],
[-2.2769672870635986, -1.2132309675216675],
[0.9184022545814514, -0.34428009390830994],
[-0.3902314603328705, 0.21785245835781097],
[3.120687484741211, 1.3077973127365112],
[1.587440848350525, -1.6506884098052979],
[-1.718808889389038, -0.038405973464250565],
[-0.6888407468795776, -0.8402308821678162],
[-0.7981445789337158, -1.1117373704910278],
[-2.4124443531036377, 1.3419722318649292],
[-0.6611530184745789, 0.9939885139465332],
[-0.33103418350219727, -0.16702833771705627],
[-2.4091389179229736, -2.326857566833496],
[1.6610108613967896, -2.159703254699707],
[0.014884627424180508, 0.3887578248977661],
[0.029668325558304787, 1.8786455392837524],
[1.180362582206726, 2.699317216873169],
[1.821286678314209, -0.5960053205490112],
[-0.44835323095321655, 3.327436685562134],
[-0.3714401423931122, -2.1466753482818604],
[-1.1103475093841553, -2.4536871910095215],
[-0.39110705256462097, 0.6670510172843933],
[0.474752813577652, -1.1959707736968994],
[-0.013110585510730743, -2.52519154548645],
[-2.0836575031280518, -1.703289270401001],
[-1.1077687740325928, -0.1252644956111908],
[-0.4138077199459076, 1.1837692260742188],
[-1.977599024772644, 1.688241720199585],
[-1.659559965133667, -2.1387736797332764],
[0.03242531046271324, 0.6526556015014648],
[0.9127950072288513, 0.6099498867988586],
[-0.38478314876556396, 0.433487206697464],
[0.27454206347465515, -0.27719801664352417],
[0.10388526320457458, 2.2812814712524414],
[-0.014394169673323631, -3.177137613296509],
[-1.2871228456497192, -0.8961855173110962],
[0.5720916986465454, -0.921597957611084],
[1.1159656047821045, -0.7609877586364746],
[2.4383342266082764, -2.2983546257019043],
[-0.294057160615921, -0.9770799875259399],
[-0.9342701435089111, 1.107579231262207],
[-1.549338698387146, 3.090520143508911],
[2.6076579093933105, 2.051239013671875],
[-0.9259037375450134, 1.407211184501648],
[-0.1747353971004486, 0.540488600730896],
[-0.8963701725006104, 0.8271111249923706],
[0.6480194926261902, 1.0128909349441528],
[0.980783998966217, -0.06156221032142639],
[-0.16883476078510284, 1.0601658821105957],
[0.5839992761611938, 0.004697148688137531],
[-0.34228450059890747, -1.2423977851867676],
[2.500824451446533, 0.3665279746055603],
[-0.17641609907150269, 1.3529551029205322],
[0.05378641560673714, 2.817232847213745],
[-1.2391047477722168, 2.354328155517578],
[0.630434513092041, -0.668536365032196],
[1.7576488256454468, 0.6738647818565369],
[0.4435231387615204, 0.6000469326972961],
[-0.08794835954904556, -0.11511358618736267],
[1.6540337800979614, 0.33995017409324646],
[-0.04202975332736969, -0.5375117063522339],
[-0.4247745871543884, -0.7897617220878601],
[0.06695003807544708, 1.2000739574432373],
[-3.2508881092071533, 0.28734830021858215],
[-1.613816261291504, 0.4944162368774414],
[1.3598989248275757, 0.26117825508117676],
[2.308382511138916, 1.3462618589401245],
[-1.2137469053268433, -1.9254342317581177],
[-0.4889402985572815, 1.8136259317398071],
[-0.1870335340499878, -0.3480615019798279],
[1.0766386985778809, -1.0627082586288452],
[0.4651014506816864, 2.131748914718628],
[-0.1306295394897461, -0.7811847925186157],
[0.06433182954788208, -1.5397958755493164],
[-0.2894323468208313, -0.5789554715156555],
[-0.6081662178039551, 0.4845278263092041],
[2.697964668273926, -0.18515698611736298],
[0.1277363896369934, -0.7221432328224182],
[0.8700758218765259, 0.35042452812194824],
[0.22088994085788727, 0.495242178440094],
[-2.5843818187713623, -0.8000828623771667],
[0.6732649803161621, -1.4362232685089111],
[-1.5286413431167603, 1.0417330265045166],
[-1.1222513914108276, -0.6269875764846802],
[-0.9752035140991211, -0.8750635385513306],
[-2.6369473934173584, 0.6918523907661438],
[0.14478731155395508, -0.041986867785453796],
[-1.5629483461380005, 1.4369450807571411],
[0.38952457904815674, -2.16428804397583],
[-0.16885095834732056, 0.7976621985435486],
[-3.12416934967041, 1.256506085395813],
[0.6843105554580688, -0.4203019142150879],
[1.9345275163650513, 1.934950351715088],
[0.012184220366179943, -2.1080918312072754],
[-0.6350273489952087, 0.7358828186988831],
[-0.837304949760437, -0.6214472651481628],
[0.08211923390626907, -0.9472538232803345],
[2.9332995414733887, -1.4956780672073364],
[1.3806978464126587, -0.2916182279586792],
[0.06773144006729126, 0.9285762310028076],
[-1.1943119764328003, 1.5963770151138306],
[1.6395620107650757, -0.32285431027412415],
[-1.390851378440857, -0.08273141086101532],
[1.816330909729004, -1.2812227010726929],
[0.7921574711799622, -2.1135804653167725],
[0.5817914605140686, 1.2644577026367188],
[1.929347038269043, -0.2386285960674286],
[0.8877345323562622, 1.190008521080017],
[1.4732073545455933, 0.8935023546218872],
[-2.8518524169921875, -1.5478795766830444],
[0.2439267635345459, 0.7576767802238464],
[0.5246709585189819, -2.606659412384033],
[1.150876760482788, 1.4073830842971802],
[-0.2643202245235443, 2.0634236335754395],
[1.555483341217041, -0.0023102816194295883],
[2.0830578804016113, -1.7225427627563477],
[-0.5424830317497253, -1.070199728012085],
[0.9168899655342102, 0.8955540060997009],
[-0.8120972514152527, 2.696739912033081],
[-0.29908373951911926, -1.5310651063919067],
[1.2320337295532227, -1.556247353553772],
[1.8612544536590576, 0.08704725652933121],
[0.22133447229862213, -1.8091708421707153],
[-0.4403655230998993, -0.38571012020111084],
[-1.88539457321167, 1.192205786705017],
[2.239687919616699, 0.004709010478109121],
[1.139495611190796, 0.45733731985092163],
[-1.507995367050171, 0.19716016948223114],
[0.46986445784568787, 1.5422041416168213],
[-1.2573751211166382, -0.35984551906585693],
[-1.7415345907211304, -0.6020717024803162],
[1.0751984119415283, 0.19006384909152985],
[2.24186635017395, -0.46343153715133667],
[0.3610347509384155, -0.07658443599939346],
[-1.3111497163772583, 0.432013601064682],
[0.6164408326148987, 0.24538464844226837],
[-1.9266542196273804, -0.3256155550479889],
[-0.5870336890220642, -0.1879584938287735],
[-1.0476511716842651, 0.3677721917629242],
[-1.229940414428711, 1.2433830499649048],
[0.18550436198711395, 0.22753673791885376],
[-0.017921989783644676, 0.12625974416732788],
[1.1659504175186157, -0.5020995736122131],
[-0.5983408093452454, -1.40438973903656],
[0.7519024014472961, -0.16282692551612854],
[0.9920787811279297, -1.344896912574768],
[-0.8103678226470947, 0.3064485788345337],
[0.6956969499588013, 1.8208192586898804],
[-2.7830491065979004, -0.2299390584230423],
[-0.34681546688079834, 2.4890666007995605],
[-1.4452646970748901, -1.2216600179672241],
[-2.1872897148132324, 0.8926076292991638],
[1.706072211265564, -2.8440372943878174],
[1.1119003295898438, -2.4923460483551025],
[-2.582794666290283, 2.0973289012908936],
[0.04987720400094986, -0.2964983284473419],
[-2.063807487487793, -0.7847916483879089],
[-0.4068813621997833, 0.9135897755622864],
[-0.9814359545707703, -0.3874954879283905],
[-1.4227229356765747, 0.7337291240692139],
[0.3065044581890106, 1.3125417232513428],
[1.2160996198654175, -1.9643305540084839],
[-1.2163853645324707, 0.14608727395534515],
[-2.3030710220336914, -0.37558120489120483],
[0.9232977628707886, 2.1843791007995605],
[-0.1989777386188507, 1.651851773262024],
[-0.714374840259552, -0.39365994930267334],
[-0.7805715799331665, -2.099881887435913],
[0.9015759229660034, -1.7053706645965576],
[0.1033422127366066, 1.5256654024124146],
[-1.8773194551467896, 2.324174165725708],
[1.9227174520492554, 2.7441604137420654],
[-0.5994020104408264, 0.23984014987945557],
[1.3496100902557373, -0.9126054644584656],
[-0.8765304088592529, -3.1877026557922363],
[-1.2040035724639893, -1.5169521570205688],
[1.4261796474456787, 2.150200128555298],
[1.463774561882019, 1.6656692028045654],
[0.20364105701446533, -0.4988172650337219],
[0.5195154547691345, -0.24067887663841248],
[-1.1116786003112793, -1.1599653959274292],
[-0.8490808606147766, -0.1681060940027237],
[0.3189965784549713, -0.9641751646995544],
[-0.5664751529693604, -0.5951744318008423],
[-1.6347930431365967, -0.9137664437294006],
[0.44048091769218445, -0.47259435057640076],
[-2.147747039794922, 0.47442489862442017],
[1.834734320640564, 1.4462147951126099],
[1.1777573823928833, 1.0659226179122925],
[-0.9568989872932434, 0.09495053440332413],
[-1.838529348373413, 0.2950586676597595],
[-0.4800611734390259, 0.014894310384988785],
[-0.5235516428947449, -1.7687653303146362],
[2.0735011100769043, -0.8825281262397766],
[2.637502431869507, 0.8455678224563599],
[2.606602907180786, -0.7848446369171143],
[-1.1886937618255615, 0.9330510497093201],
[0.38082656264305115, 0.13328030705451965],
[0.6847941875457764, 0.7384101152420044],
[1.2638574838638306, -0.007309418171644211],
[0.18292222917079926, -1.22371244430542],
[0.8143821954727173, 1.4976691007614136],
[0.6571850776672363, 0.48368802666664124],
[-0.6991601586341858, 2.150190830230713],
[0.8101756572723389, 0.10206498205661774],
[-0.08768226951360703, -1.084917664527893],
[-0.7208092212677002, 0.03657956421375275],
[0.3211449086666107, 1.803687334060669],
[-0.7835946083068848, 1.6869111061096191],
]
)
if (p, n) == (2, 64):
return torch.tensor(
[
[-2.7216711044311523, 0.14431366324424744],
[-0.766914427280426, 1.7193410396575928],
[-2.2575762271881104, 1.2476624250411987],
[1.233758807182312, -2.3560616970062256],
[0.8701965808868408, -0.2649352252483368],
[1.4506438970565796, 2.1776366233825684],
[-0.06305818259716034, 1.9049758911132812],
[2.536226511001587, 0.563927412033081],
[0.4599496126174927, -1.8745561838150024],
[-1.900517225265503, -0.30703988671302795],
[0.09386251866817474, 0.8755807280540466],
[1.946500539779663, -0.6743080615997314],
[2.1338934898376465, 1.4581491947174072],
[0.9429940581321716, -0.8038390278816223],
[2.0697755813598633, -1.614896535873413],
[0.772676408290863, 0.22017823159694672],
[1.0689979791641235, -1.525044322013855],
[0.6813604831695557, 1.1345642805099487],
[0.4706456661224365, 2.606626272201538],
[-1.294018030166626, -0.4372096061706543],
[-0.09134224057197571, 0.4610418677330017],
[-0.7907772064208984, -0.48412787914276123],
[0.060459110885858536, -0.9172890186309814],
[-0.5855047702789307, 2.56172513961792],
[0.11484206467866898, -2.659848213195801],
[-1.5893300771713257, 2.188580274581909],
[1.6750942468643188, 0.7089915871620178],
[-0.445697546005249, 0.7452405095100403],
[-1.8539940118789673, -1.8377939462661743],
[-1.5791912078857422, -1.017285943031311],
[-1.030419945716858, -1.5746369361877441],
[-1.9511750936508179, 0.43696075677871704],
[-0.3446580767631531, -1.8953213691711426],
[-1.4219647645950317, 0.7676230669021606],
[-0.9191089272499084, 0.5021472573280334],
[0.20464491844177246, 1.3684605360031128],
[0.5402919054031372, 0.6699410676956177],
[1.8903915882110596, 0.03638288006186485],
[0.4723062515258789, -0.6216739416122437],
[-0.41345009207725525, -0.22752176225185394],
[2.7119064331054688, -0.5111885070800781],
[1.065286636352539, 0.6950305700302124],
[0.40629103779792786, -0.14339995384216309],
[1.2815024852752686, 0.17108257114887238],
[0.01785222627222538, -0.43778058886528015],
[0.054590027779340744, -1.4225547313690186],
[0.3076786696910858, 0.30697619915008545],
[-0.9498570561408997, -0.9576997756958008],
[-2.4640724658966064, -0.9660449028015137],
[1.3714425563812256, -0.39760473370552063],
[-0.4857747256755829, 0.2386789172887802],
[1.2797833681106567, 1.3097363710403442],
[0.5508887767791748, -1.1777795553207397],
[-1.384316325187683, 0.1465839296579361],
[-0.46556955575942993, -1.2442727088928223],
[-0.3915477693080902, -0.7319604158401489],
[-1.4005504846572876, 1.3890998363494873],
[-0.8647305965423584, 1.0617644786834717],
[-0.8901953101158142, -0.01650036871433258],
[-0.9893633723258972, -2.4662880897521973],
[1.445534110069275, -1.049334168434143],
[-0.041650623083114624, 0.012734669260680676],
[-0.3302375078201294, 1.26217782497406],
[0.6934980154037476, 1.7714335918426514],
]
)
elif (p, n) == (2, 16):
return torch.tensor(
[
[-0.8996632695198059, -1.6360418796539307],
[-0.961183488368988, 1.5999565124511719],
[-1.882026195526123, 0.678778350353241],
[0.36300793290138245, -1.9667866230010986],
[-0.6814072728157043, -0.576818585395813],
[0.7270012497901917, 0.6186859607696533],
[0.3359416127204895, 1.8371193408966064],
[1.859930396080017, 0.036668598651885986],
[0.17208248376846313, -0.9401724338531494],
[-1.7599700689315796, -0.6244229674339294],
[-0.8993809223175049, 0.32267823815345764],
[0.839488685131073, -0.3017036020755768],
[1.5314953327178955, 1.2942044734954834],
[-0.0011779458727687597, 0.00022069070837460458],
[1.4274526834487915, -1.207889199256897],
[-0.16123905777931213, 0.8787511587142944],
]
)
elif (p, n) == (1, 16):
return torch.tensor(
[
[-2.7325894832611084],
[-2.069017171859741],
[-1.6180464029312134],
[-1.2562311887741089],
[-0.9423404335975647],
[-0.6567591428756714],
[-0.38804829120635986],
[-0.12839503586292267],
[0.12839503586292267],
[0.38804829120635986],
[0.6567591428756714],
[0.9423404335975647],
[1.2562311887741089],
[1.6180464029312134],
[2.069017171859741],
[2.7325894832611084],
]
)
elif (p, n) == (1, 8):
return torch.tensor(
[
[-2.1519455909729004],
[-1.3439092636108398],
[-0.7560052871704102],
[-0.2450941801071167],
[0.2450941801071167],
[0.7560052871704102],
[1.3439092636108398],
[2.1519455909729004],
]
)
elif (p, n) == (1, 4):
return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
else:
raise NotImplementedError(f"Unsupported p={p}, n={n}")
def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
"""Quantize a 2D weight matrix with HIGGS and pack it into FLUTE format."""
assert len(weight.shape) == 2, "Only 2D weights are supported for now"
grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
device = weight.device
dtype = weight.dtype
weight = weight.clone().float()
# Pad to Hadamard transform size
weight = pad_to_block(weight, [1], hadamard_size)
# Scale and Hadamard transform
mult = weight.shape[1] // hadamard_size
weight = weight.reshape(-1, mult, hadamard_size)
scales = torch.linalg.norm(weight, axis=-1)
weight = hadamard_transform(weight, 1) / scales[:, :, None]
# Pad to edenn_d and project
weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
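# Nearest-neighbor assignment in 64-row chunks to bound peak memory:
# argmax(2 * w @ g.T - ||g||^2) over grid points g equals argmin(||w - g||^2).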
for i in range(0, weight.shape[0], 64):
codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight
codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)
weight, scales, tables, tables2 = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
num_bits=bits,
group_size=group_size,
vector_size=p,
dtype=dtype,
device=device,
)
return {
"weight": weight,
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
}
class HiggsLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_bits: int,
bias=True,
dtype: torch.dtype = None,
device: torch.device = None,
group_size: int = 256,
hadamard_size: int = 1024,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size
self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)
assert in_features % group_size == 0
assert num_bits in [2, 3, 4]
self.weight = nn.Parameter(
torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
requires_grad=False,
)
self.scales = nn.Parameter(
torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
)
self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
self.tables2 = nn.Parameter(
torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
)
if bias:
self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
else:
self.register_parameter("bias", None)
self.workspace = None # must be set externally to be reused among layers
def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)
if self.workspace is None:
raise RuntimeError("Workspace must be set before calling forward")
return flute.qgemm_hadamard(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.num_bits,
self.group_size,
self.hadamard_size,
)
def replace_with_higgs_linear(
model,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`HiggsConfig`):
The quantization config object that contains the quantization parameters.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
from accelerate import init_empty_weights
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear):
# Check if the current key is not in the `quantization_config.modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(current_key_name_str.endswith(key) for key in quantization_config.modules_to_not_convert):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
model._modules[name] = HiggsLinear(
in_features,
out_features,
bias=module.bias is not None,
num_bits=quantization_config.bits,
hadamard_size=quantization_config.hadamard_size,
group_size=quantization_config.group_size,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_higgs_linear(
module,
quantization_config=quantization_config,
current_key_name=current_key_name,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def dequantize_higgs(model, current_key_name=None):
"""
Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
Args:
model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
current_key_name (list, optional): A list to keep track of the current module names during recursion. Defaults to None.
Returns:
torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
"""
with torch.no_grad():
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, HiggsLinear):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = torch.nn.Linear(
in_features,
out_features,
bias=module.bias is not None,
device=module.scales.device,
dtype=module.scales.dtype,
)
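# Feeding the identity through the quantized layer gives module(I) = I @ W^T = W^T
# (HiggsLinear.forward applies no bias), so transposing recovers the dense weight.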
model._modules[name].weight.data = module(
torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
).T.contiguous()
if len(list(module.children())) > 0:
_ = dequantize_higgs(
module,
current_key_name=current_key_name,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model


@@ -24,6 +24,7 @@ from ..utils.quantization_config import (
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
@@ -40,6 +41,7 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
@@ -54,6 +56,7 @@ AUTO_QUANTIZER_MAPPING = {
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"higgs": HiggsHfQuantizer,
"hqq": HqqHfQuantizer,
"compressed-tensors": CompressedTensorsHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
@@ -73,6 +76,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,
"higgs": HiggsConfig,
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
"vptq": VptqConfig,


@@ -0,0 +1,232 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
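# FLUTE kernels are compiled for specific streaming-multiprocessor (SM) counts,
# so we map the CUDA compute capability to an SM count FLUTE ships kernels for
# (e.g. 8.0 is an A100 with 108 SMs, 8.9 an RTX 4090 with 128 SMs).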
def get_num_sms_from_device(device):
target_device_cc = torch.cuda.get_device_capability(device=device)
if target_device_cc == (8, 6):
return 84
elif target_device_cc == (8, 0):
return 108
elif target_device_cc == (8, 9):
return 128
else:
raise NotImplementedError(
f"Device capability {target_device_cc} is not supported by FLUTE (yet). "
"To check your device's capability, see https://developer.nvidia.com/cuda-gpus"
)
class HiggsHfQuantizer(HfQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
"""
requires_calibration = False
requires_parameters_quantization = True
required_packages = ["flute-kernel", "fast_hadamard_transform"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
if not is_accelerate_available():
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
if not is_flute_available():
raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
if not is_hadamard_available():
raise ImportError(
"Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
)
if device_map is None:
raise ValueError(
"You are attempting to load a HIGGS model without setting a `device_map`."
" Please set a `device_map` comprised of CUDA devices."
)
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
torch_dtype = torch.float16
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
raise ValueError(
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
)
return torch_dtype
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
"""Quantizes `param_value` into FLUTE-packed weight, scales, and lookup tables."""
from ..integrations import quantize_with_higgs
flute_dict = quantize_with_higgs(
param_value.to(target_device),
self.quantization_config.bits,
self.quantization_config.p,
self.quantization_config.group_size,
self.quantization_config.hadamard_size,
)
del param_value
module, tensor_name = get_module_from_name(model, param_name)
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
else:
raise ValueError(f"Unexpected key {key} in module {module}")
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
requires_grad=False,
)
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
**kwargs,
):
from ..integrations import replace_with_higgs_linear
replace_with_higgs_linear(
model,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
import flute.utils
from ..integrations import HiggsLinear
flute_workspaces = {}
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device):
new_device = module.weight.device
new_num_sms = get_num_sms_from_device(new_device)
module.weight.data = flute.utils.pack(
flute.utils.unpack(
weight=module.weight.data,
scales=module.scales.data,
workspace=module.workspace,
num_bits=module.num_bits,
group_size=module.group_size,
num_sms_packed=module.num_sms_packed.item(),
).T.contiguous(),
module.num_bits,
module.group_size,
)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
requires_grad=False,
)
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import HiggsLinear
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
def is_serializable(self, safe_serialization=None):
return True
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
from ..integrations import HiggsLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
# Only quantize weights of HiggsLinear modules that are not already quantized
return True
else:
return False
def _dequantize(self, model):
from ..integrations import dequantize_higgs
model = dequantize_higgs(model)
return model


@@ -79,12 +79,14 @@ from .utils import (
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flax_available,
is_flute_available,
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_hadamard_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -1239,6 +1241,15 @@ def require_fbgemm_gpu(test_case):
return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
def require_flute_hadamard(test_case):
"""
Decorator marking a test that requires flute-kernel and fast_hadamard_transform
"""
return unittest.skipUnless(
is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform"
)(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer


@@ -140,12 +140,14 @@ from .import_utils import (
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_flute_available,
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_hadamard_available,
is_hqq_available,
is_in_notebook,
is_ipex_available,


@@ -128,6 +128,7 @@ except importlib.metadata.PackageNotFoundError:
_faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_hadamard_available = _is_package_available("fast_hadamard_transform")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
@@ -332,6 +333,10 @@ def is_torch_deterministic():
return True
def is_hadamard_available():
return _hadamard_available
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
@@ -615,6 +620,13 @@ def is_flax_available():
return _flax_available
def is_flute_available():
try:
# Compare versions numerically; a plain string comparison would mis-order e.g. "0.10.0" < "0.3.0".
return importlib.util.find_spec("flute") is not None and version.parse(
importlib.metadata.version("flute-kernel")
) >= version.parse("0.3.0")
except importlib.metadata.PackageNotFoundError:
return False
def is_ftfy_available():
return _ftfy_available


@@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum):
VPTQ = "vptq"
QUANTO = "quanto"
EETQ = "eetq"
HIGGS = "higgs"
HQQ = "hqq"
COMPRESSED_TENSORS = "compressed-tensors"
FBGEMM_FP8 = "fbgemm_fp8"
@@ -1340,6 +1341,58 @@ class FbgemmFp8Config(QuantizationConfigMixin):
return loading_attibutes_dict
@dataclass
class HiggsConfig(QuantizationConfigMixin):
"""
HiggsConfig is a configuration class for quantization using the HIGGS method.
Args:
bits (int, *optional*, defaults to 4):
Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4.
p (int, *optional*, defaults to 2):
Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2.
modules_to_not_convert (`list`, *optional*, defaults to ["lm_head"]):
List of linear layers that should not be quantized.
hadamard_size (int, *optional*, defaults to 512):
Hadamard transform block size for the HIGGS method. The input dimension of weight matrices is padded to a multiple of this value. Decreasing it below 512 will reduce quantization quality.
group_size (int, *optional*, defaults to 256):
Scale group size for the HIGGS method. Can be 64, 128 or 256; must be a divisor of `hadamard_size`. Decreasing it barely affects performance.
"""
def __init__(
self,
bits: int = 4,
p: int = 2,
modules_to_not_convert: Optional[List[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
**kwargs,
):
if modules_to_not_convert is None:
modules_to_not_convert = ["lm_head"]
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if self.bits not in [2, 3, 4]:
raise ValueError("bits must be 2, 3, or 4")
if self.p not in [1, 2]:
raise ValueError("p must be 1 or 2. 2 is always better in practice")
if self.group_size not in [64, 128, 256]:
raise ValueError("group_size must be 64, 128, or 256")
if self.hadamard_size % self.group_size != 0:
raise ValueError("hadamard_size must be divisible by group_size")
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.


@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HiggsConfig, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_flute_hadamard,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class HiggsConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks that converting a config object to a dict yields the same values as the config object
"""
quantization_config = HiggsConfig()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks that building a config object from a dict yields a config object matching the dict
"""
dict = {"modules_to_not_convert": ["embed_tokens", "lm_head"], "quant_method": "higgs"}
quantization_config = HiggsConfig.from_dict(dict)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
@slow
@require_torch_gpu
@require_flute_hadamard
@require_accelerate
# @require_read_token
class HiggsTest(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3.1-8B"
input_text = "A quick brown fox jumps over the"
max_new_tokens = 2
EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog"
device_map = "cuda"
# called only once for all test in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
quantization_config = HiggsConfig()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from transformers.integrations import HiggsLinear, replace_with_higgs_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = HiggsConfig()
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config)
nb_higgs_linear = 0
for module in model.modules():
if isinstance(module, HiggsLinear):
nb_higgs_linear += 1
self.assertEqual(nb_linears - 1, nb_higgs_linear)
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = HiggsConfig(modules_to_not_convert=["fc1"])
model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config)
nb_higgs_linear = 0
for module in model.modules():
if isinstance(module, HiggsLinear):
nb_higgs_linear += 1
self.assertEqual(nb_linears - 24, nb_higgs_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = HiggsConfig()
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_save_pretrained_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
self.assertTrue(set(model.hf_device_map.values()) == {0, 1})
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@unittest.skip("This will almost surely OOM. Enable when switched to a smaller model")
def test_dequantize(self):
"""
Test the ability to dequantize a model
"""
self.quantized_model.dequantize()
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)