mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add_dtype_check
This commit is contained in:
parent
6acb4e43a7
commit
93a233e82c
1 changed file with 5 additions and 0 deletions
|
|
@@ -142,6 +142,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
|
||||
# Add here check for loaded components' dtypes once serialization is implemented
|
||||
if self.pre_quantized:
|
||||
if param_value.dtype != torch.uint8:
|
||||
raise ValueError(
|
||||
f"Incompatible dtype `{param_value.dtype}` when loading 4-bit prequantized weight. Expected `torch.uint8`."
|
||||
)
|
||||
return True
|
||||
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
|
||||
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
|
||||
|
|
|
|||
Loading…
Reference in a new issue