From 93a233e82c30632e0c5694fbba163e7ad56f80cd Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 10 Dec 2024 15:07:17 +0000 Subject: [PATCH] add_dtype_check --- src/transformers/quantizers/quantizer_bnb_4bit.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 98d57e225..30aeed8b7 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -142,6 +142,11 @@ class Bnb4BitHfQuantizer(HfQuantizer): module, tensor_name = get_module_from_name(model, param_name) if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): # Add here check for loaded components' dtypes once serialization is implemented + if self.pre_quantized: + if param_value.dtype != torch.uint8: + raise ValueError( + f"Incompatible dtype `{param_value.dtype}` when loading 4-bit prequantized weight. Expected `torch.uint8`." + ) return True elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias": # bias could be loaded by regular set_module_tensor_to_device() from accelerate,