fix quantization logic

MekkCyber 2025-02-05 09:44:31 +00:00
parent 70749dfd9b
commit 83912834cc


@@ -72,52 +72,41 @@ class FP8HfQuantizer(HfQuantizer):
# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
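(Side note, not part of the diff: torch.float8_e4m3fn has no infinities and its finfo range is [-448, 448], so fp8_max above is 448.0; the scaling below simply stretches each block's largest absolute value onto that limit.)

import torch
print(torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max)  # -448.0 448.0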
if self.quantization_config.weight_block_size is not None:
if self.quantization_config.weight_block_size is None:
self.quantization_config.weight_block_size = (128, 128)
else:
block_size_m, block_size_n = self.quantization_config.weight_block_size
rows, cols = param_value.shape[-2:]
# Check if dimensions are divisible by block sizes
if rows % block_size_m != 0 or cols % block_size_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
)
param_value_orig_shape = param_value.shape
# Create blocks by reshaping into (block_size_m, block_size_n) tiles and permuting
param_value = param_value.reshape(-1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n).permute(0, 1, 3, 2, 4)
# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
# Expand scale to match block dimensions for multiplication
scale = scale.unsqueeze(-1).unsqueeze(-1)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)
# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
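The block-wise path above can be exercised on its own. A minimal sketch, assuming a small 2-D weight and toy 2x2 blocks instead of the (128, 128) default (the helper name blockwise_fp8_quantize is made up for illustration): it carves the weight into tiles with reshape + permute, scales each tile so its largest absolute value lands on fp8_max, clamps and casts to float8_e4m3fn, then undoes the permute before flattening back, which is why the quantized weight keeps the original (rows, cols) layout.

import torch

def blockwise_fp8_quantize(weight, block_m=2, block_n=2):
    # hypothetical standalone version of the block-wise branch above
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    rows, cols = weight.shape
    # group into (block_m, block_n) tiles: (1, n_row_blocks, n_col_blocks, block_m, block_n)
    blocks = weight.reshape(-1, rows // block_m, block_m, cols // block_n, block_n).permute(0, 1, 3, 2, 4)
    scale = fp8_max / torch.amax(blocks.abs(), dim=(-1, -2))  # one scale per block
    q = torch.clamp(blocks * scale.unsqueeze(-1).unsqueeze(-1), fp8_min, fp8_max).to(torch.float8_e4m3fn)
    q = q.permute(0, 1, 3, 2, 4).reshape(weight.shape)        # restore the original layout
    return q, scale.squeeze(0).reciprocal()                   # inverse scales, as stored below

w = torch.randn(4, 6)
q, scale_inv = blockwise_fp8_quantize(w)
print(q.shape, q.dtype, scale_inv.shape)  # torch.Size([4, 6]) torch.float8_e4m3fn torch.Size([2, 3])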
rows, cols = param_value.shape[-2:]
# Check if dimensions are divisible by block sizes
if rows % block_size_m != 0 or cols % block_size_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
)
# Create blocks using unfold
param_value = param_value.unfold(-2, block_size_m, block_size_m)
param_value = param_value.unfold(-2, block_size_n, block_size_n)
# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
# Expand scale to match block dimensions for multiplication
scale = scale.unsqueeze(-1).unsqueeze(-1)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value.shape[:-4] + (rows, cols))
# Reshape scale to match the number of blocks
scale = scale.reshape(scale.shape[:-2] + (-1,)).squeeze(-1).reciprocal()
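For contrast, the unfold-based path being replaced above builds blocks of shape (n_row_blocks, n_col_blocks, block_m, block_n) and then reshapes that straight back to (rows, cols). That flattening scrambles the element order (in the toy case below each 2x2 block ends up flattened into a single output row) instead of restoring the original layout, which appears to be the logic error this commit fixes. A quick standalone check of just the layout round trip (toy sizes, not the committed code):

import torch

rows, cols, bm, bn = 4, 4, 2, 2
x = torch.arange(rows * cols, dtype=torch.float32).reshape(rows, cols)

# old path: unfold twice, then flatten straight back
blocks = x.unfold(-2, bm, bm).unfold(-2, bn, bn)  # (rows//bm, cols//bn, bm, bn)
back_old = blocks.reshape(rows, cols)

# fixed ordering: undo the blocking with a permute before flattening
back_new = blocks.permute(0, 2, 1, 3).reshape(rows, cols)

print(torch.equal(back_old, x))  # False - element order is scrambled
print(torch.equal(back_new, x))  # True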
else:
# Per-tensor quantization
max_abs = torch.max(torch.abs(param_value))
# print("###################max_abs#################", max_abs)
scale = fp8_max / max_abs
# print("###################scale#################", scale)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# For per-tensor quantization, we just need a single scale value
scale = torch.tensor([[scale]]).reciprocal()
# print("###################after reciprocal scale#################", scale)
# Store the quantized weights and scales in the module
module._parameters[tensor_name] = quantized_param.to(target_device)
module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
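At this point the module holds the FP8 weight plus weight_scale_inv, the reciprocal of the scale computed above (one value per block, or a single value in the per-tensor case). A hedged sketch of the matching dequantization arithmetic for the block-wise case, reusing the hypothetical blockwise_fp8_quantize helper from the earlier sketch (the real inference path presumably lives in the FP8Linear integration imported below, not here):

import torch

def blockwise_dequantize(q_weight, weight_scale_inv, block_m=2, block_n=2):
    # q_weight: (rows, cols) float8, weight_scale_inv: (rows // block_m, cols // block_n)
    rows, cols = q_weight.shape
    w = q_weight.to(torch.float32).reshape(rows // block_m, block_m, cols // block_n, block_n)
    w = w * weight_scale_inv.unsqueeze(1).unsqueeze(-1)  # broadcast one inverse scale per block
    return w.reshape(rows, cols)

w = torch.randn(4, 6)
q, scale_inv = blockwise_fp8_quantize(w)   # hypothetical helper from the earlier sketch
w_hat = blockwise_dequantize(q, scale_inv)
print((w - w_hat).abs().max())             # small, limited by float8_e4m3fn rounding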
@@ -173,6 +162,22 @@ class FP8HfQuantizer(HfQuantizer):
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FP8Linear
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FP8Linear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
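update_missing_keys above keeps checkpoint loading from warning about tensors the quantizer creates itself: for every FP8Linear module, missing keys that belong to that module but do not end in .weight or .bias (such as the weight_scale_inv registered above) are dropped from the report, while genuine weight or bias misses are still surfaced. A rough illustration of the filtering with made-up module and key names, simplified to plain string checks that mirror the condition above:

# hypothetical names, for illustration only
fp8_module_names = ["model.layers.0.mlp.down_proj"]      # modules that are FP8Linear
missing_keys = [
    "model.layers.0.mlp.down_proj.weight_scale_inv",      # created at quantization time
    "model.layers.0.mlp.down_proj.weight",                # a genuine miss, kept
    "model.embed_tokens.weight",                          # unrelated module, kept
]

kept = [
    k for k in missing_keys
    if not any(
        name in k and not k.endswith((".weight", ".bias"))
        for name in fp8_module_names
    )
]
print(kept)  # ['model.layers.0.mlp.down_proj.weight', 'model.embed_tokens.weight']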
def is_serializable(self, safe_serialization=None):
return True