add_dtype_check

This commit is contained in:
MekkCyber 2024-12-10 15:07:17 +00:00
parent 6acb4e43a7
commit 93a233e82c

View file

@ -142,6 +142,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
if self.pre_quantized:
if param_value.dtype != torch.uint8:
raise ValueError(
f"Incompatible dtype `{param_value.dtype}` when loading 4-bit prequantized weight. Expected `torch.uint8`."
)
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,