Split and clean up GGUF quantization tests (#35502)

* clean up ggml test

Signed-off-by: Isotr0py <2037008807@qq.com>

* port remaining tests

Signed-off-by: Isotr0py <2037008807@qq.com>

* further cleanup

Signed-off-by: Isotr0py <2037008807@qq.com>

* format

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix broken tests

Signed-off-by: Isotr0py <2037008807@qq.com>

* update comment

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

Signed-off-by: Isotr0py <2037008807@qq.com>

* reorganize tests

Signed-off-by: Isotr0py <2037008807@qq.com>

* k-quants use qwen2.5-0.5B

Signed-off-by: Isotr0py <2037008807@qq.com>

* move ggml tokenization test

Signed-off-by: Isotr0py <2037008807@qq.com>

* remove dead code

Signed-off-by: Isotr0py <2037008807@qq.com>

* add assert for serilization test

Signed-off-by: Isotr0py <2037008807@qq.com>

* use str for parameterize

Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-01-27 22:46:57 +08:00 committed by GitHub
parent 5c576f5a66
commit e57b459997
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -15,6 +15,8 @@
import tempfile
import unittest
from parameterized import parameterized
from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import (
require_gguf,
@ -23,20 +25,205 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_torch_available
from transformers.utils import is_gguf_available, is_torch_available
if is_torch_available():
import torch
if is_gguf_available():
from gguf import GGMLQuantizationType as QuantType
@require_gguf
@require_torch_gpu
@slow
class GgufQuantizationTests(unittest.TestCase):
"""
Test cases for weights dequantization with GGUF models.
Note: The quantization names should keep aligned with `GGMLQuantizationType` in gguf-py:
https://github.com/ggerganov/llama.cpp/blob/4b0c638b9a68f577cb2066b638c9f622d91ee661/gguf-py/gguf/constants.py#L1545-L1576
So quantization like Q4_K_M or Q4_K_S dshouldn't be added to this tests.
"""
example_text = "Hello"
def run_gguf_model(self, gguf_model_id: str, gguf_filename: str, expected_text: str):
tokenizer = AutoTokenizer.from_pretrained(gguf_model_id, gguf_file=gguf_filename)
model = AutoModelForCausalLM.from_pretrained(gguf_model_id, gguf_file=gguf_filename).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), expected_text)
@parameterized.expand(
[
# standard quants
("Q4_0", "Hello, World!\n\nStep 3: Add"),
("Q5_0", "Hello, World!\n\n5. Use a library"),
("Q8_0", "Hello, World!\n\n5. Use a library"),
],
)
def test_standard_quants(self, quant_type: str, expected_text: str):
gguf_model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename_format = "tinyllama-1.1b-chat-v1.0.{quant_type}.gguf"
gguf_filename = filename_format.format(quant_type=quant_type)
self.run_gguf_model(gguf_model_id, gguf_filename, expected_text)
# k-quants
@parameterized.expand(
[
("Q2_K", "Hello, I'm a 22 year old female"),
("Q3_K", "Hello\n\nI am trying to create a simple program that"),
("Q4_K", "Hello\n\nI am trying to create a simple program that"),
("Q5_K", "Helloveda is a 1999 Indian"),
("Q6_K", "Hello\n\nI am trying to create a simple program that"),
],
)
def test_k_quants(self, quant_type: str, expected_text: str):
gguf_model_id = "legraphista/Qwen2.5-0.5B-Instruct-IMat-GGUF"
filename_format = "Qwen2.5-0.5B-Instruct.{quant_type}.gguf"
gguf_filename = filename_format.format(quant_type=quant_type)
self.run_gguf_model(gguf_model_id, gguf_filename, expected_text)
@parameterized.expand(
[
# i-matrix
("IQ1_S", "Hello, I'm a friend of mine, I"),
("IQ1_M", "Hello, I am interested in purching a copy of"),
("IQ2_XXS", "Hello, I'm a software engineer. I'"),
("IQ2_XS", "Hello World!\n\n```\n<|user|"),
("IQ2_S", "Hello World!\n\n```\n<|user|"),
("IQ3_XXS", "Hello, I am interested in your product. Can you"),
("IQ4_XS", "Hello, world!\n\n5. Using a loop"),
("IQ3_S", "Hello, World!\n\n5. Python:\n"),
("IQ4_NL", "Hello, world!\n\n5. Using a loop"),
],
)
def test_imatrix_quants(self, quant_type: str, expected_text: str):
gguf_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
filename_format = "TinyLlama-1.1B-Chat-v1.0-{quant_type}.gguf"
gguf_filename = filename_format.format(quant_type=quant_type)
self.run_gguf_model(gguf_model_id, gguf_filename, expected_text)
@require_gguf
@require_torch_gpu
@slow
class GgufIntegrationTests(unittest.TestCase):
"""
Test cases for basic interoperability with GGUF models:
- Tokenization
- Model dtype casting and serialization
"""
example_text = "Hello"
original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
gguf_model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
gguf_filename = "tinyllama-1.1b-chat-v1.0.{quant_type}.gguf"
def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
q8_0_gguf_model_id = self.gguf_filename.format(quant_type=QuantType.Q8_0.name)
gguf_tokenizer = AutoTokenizer.from_pretrained(self.gguf_model_id, gguf_file=q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go")
for item in tqdm.tqdm(dataset["validation"]):
string = item["code"]
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
dataset = load_dataset("facebook/xnli", "all_languages")
for i, item in enumerate(tqdm.tqdm(dataset["train"].select(range(100)))):
for string in item["premise"].values():
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
# With special tokens
gguf_tokenizer = AutoTokenizer.from_pretrained(self.gguf_model_id, gguf_file=q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
gguf_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
original_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
text = "Hello <token>. <token> Hello"
encoded1 = gguf_tokenizer.encode(text)
encoded2 = original_tokenizer.encode(text)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
def test_q2_k_serialization(self):
q2_k_gguf_model_id = self.gguf_filename.format(quant_type=QuantType.Q2_K.name)
EXPECTED_TEXT = "Hello, World!\n\n[10:0"
tokenizer = AutoTokenizer.from_pretrained(self.gguf_model_id, gguf_file=q2_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.gguf_model_id, gguf_file=q2_k_gguf_model_id).to(torch_device)
orig_text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
orig_out = model.generate(**orig_text, max_new_tokens=10)
self.assertEqual(tokenizer.decode(orig_out[0], skip_special_tokens=True), EXPECTED_TEXT)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
tokenizer.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(tmpdirname)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q6_k_fp16(self):
q6_k_gguf_model_id = self.gguf_filename.format(quant_type=QuantType.Q6_K.name)
tokenizer = AutoTokenizer.from_pretrained(self.gguf_model_id, gguf_file=q6_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.gguf_model_id, gguf_file=q6_k_gguf_model_id, torch_dtype=torch.float16
).to(torch_device)
self.assertTrue(model.lm_head.weight.dtype == torch.float16)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
@require_gguf
@require_torch_gpu
@slow
class GgufModelTests(unittest.TestCase):
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
qwen2moe_model_id = "gdax/Qwen1.5-MoE-A2.7B_gguf"
@ -68,34 +255,13 @@ class GgufIntegrationTests(unittest.TestCase):
original_gemma2_model_id = "google/gemma-2-2b-it"
gemma2_model_id = "bartowski/gemma-2-2b-it-GGUF"
# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q5_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_0.gguf"
q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
# k-quants
q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
q4_k_m_stablelm_model_id = "stablelm-3b-4e1t.q4_k_m.gguf"
# imatrix
iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
iq2_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_S.gguf"
iq2_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XS.gguf"
iq2_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XXS.gguf"
iq3_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_S.gguf"
iq3_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_XXS.gguf"
iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf"
iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf"
q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q8_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B_Q8_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
q4_k_m_stablelm_model_id = "stablelm-3b-4e1t.q4_k_m.gguf"
fp16_stablelm2_model_id = "stablelm-2-1_6b.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
@ -120,237 +286,6 @@ class GgufIntegrationTests(unittest.TestCase):
example_text = "Hello"
def test_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n[10:0"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q2_k_serialization(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q2_k_gguf_model_id).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
tokenizer.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(tmpdirname)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n[10:0"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q3_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q3_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n```\n<|user"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q5_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q5_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q4_k_m(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q4_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q6_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q6_k_fp16(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q6_k_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.model_id, gguf_file=self.q6_k_gguf_model_id, torch_dtype=torch.float16
).to(torch_device)
self.assertTrue(model.lm_head.weight.dtype == torch.float16)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\nStep 3: Add"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq1_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I'm a friend of mine, I"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq1_m(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I am interested in purching a copy of"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_xs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq2_xxs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I'm a software engineer. I'"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq3_s(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq3_xxs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, I am interested in your product. Can you"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq4_xs(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_iq4_nl(self):
tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id)
model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id).to(
torch_device
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_f16(self):
tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id
).to(torch_device)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello, World!\n\n5. Node.js"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_mistral_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
model = AutoModelForCausalLM.from_pretrained(
@ -904,60 +839,3 @@ class GgufIntegrationTests(unittest.TestCase):
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
else:
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go")
for item in tqdm.tqdm(dataset["validation"]):
string = item["code"]
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
dataset = load_dataset("facebook/xnli", "all_languages")
for i, item in enumerate(tqdm.tqdm(dataset["train"].select(range(100)))):
for string in item["premise"].values():
encoded1 = gguf_tokenizer.encode(string)
encoded2 = original_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
# With special tokens
gguf_tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q8_0_gguf_model_id)
original_tokenizer = AutoTokenizer.from_pretrained(self.original_model_id)
gguf_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
original_tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<token>", rstrip=False, lstrip=False)]}
)
text = "Hello <token>. <token> Hello"
encoded1 = gguf_tokenizer.encode(text)
encoded2 = original_tokenizer.encode(text)
self.assertEqual(encoded1, encoded2)
decoded1 = gguf_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = original_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)