Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Quantization / HQQ: Fix HQQ tests on our runner (#30668)
Update test_hqq.py
This commit is contained in: parent a45c514899, commit 9c772ac888
1 changed file with 3 additions and 3 deletions
test_hqq.py
@@ -35,7 +35,7 @@ if is_hqq_available():
 
 
 class HQQLLMRunner:
-    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
+    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=compute_dtype,
@@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
-    def test_bfp16_quantized_model_with_offloading(self):
+    def test_f16_quantized_model_with_offloading(self):
         """
         Simple LLM model testing bfp16 with meta-data offloading
         """
@@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
         )
 
         hqq_runner = HQQLLMRunner(
-            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
         )
 
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
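For context, here is a minimal sketch of the from_pretrained call that the HQQLLMRunner helper wraps, as it behaves after this fix. The HqqConfig values and model id below are illustrative assumptions, not taken from the test file:

import torch
from transformers import AutoModelForCausalLM, HqqConfig

# Assumed example settings; the real test builds its own quant_config.
quant_config = HqqConfig(nbits=8, group_size=64)

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # hypothetical model id
    torch_dtype=torch.float16,  # the dtype the renamed offloading test now uses
    device_map="cuda",
    quantization_config=quant_config,
)

The switch from torch.bfloat16 to torch.float16 (and the matching test rename) presumably accommodates the project's CI runner, since float16 is supported on a wider range of GPUs than bfloat16; the new cache_dir=None default likewise lets callers construct HQQLLMRunner without passing a cache directory.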