Quantization / HQQ: Fix HQQ tests on our runner (#30668)

Update test_hqq.py
This commit is contained in:
Younes Belkada 2024-05-06 11:33:52 +02:00 committed by GitHub
parent a45c514899
commit 9c772ac888
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -35,7 +35,7 @@ if is_hqq_available():
class HQQLLMRunner:
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=compute_dtype,
@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_bfp16_quantized_model_with_offloading(self):
def test_f16_quantized_model_with_offloading(self):
"""
Simple LLM model testing bfp16 with meta-data offloading
"""
@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)