From 83912834cc28d732e92b065617ec374bb4adbd81 Mon Sep 17 00:00:00 2001
From: MekkCyber
Date: Wed, 5 Feb 2025 09:44:31 +0000
Subject: [PATCH] fix quantization logic

---
 src/transformers/quantizers/quantizer_fp8.py | 96 +++++++++++-----------
 1 file changed, 51 insertions(+), 45 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_fp8.py b/src/transformers/quantizers/quantizer_fp8.py
index fceb3eeb1..58e7f4cb1 100644
--- a/src/transformers/quantizers/quantizer_fp8.py
+++ b/src/transformers/quantizers/quantizer_fp8.py
@@ -72,52 +72,42 @@ class FP8HfQuantizer(HfQuantizer):
         # Get FP8 min/max values
         fp8_min = torch.finfo(torch.float8_e4m3fn).min
         fp8_max = torch.finfo(torch.float8_e4m3fn).max
-
-        if self.quantization_config.weight_block_size is not None:
-
-            block_size_m, block_size_n = self.quantization_config.weight_block_size
-
-            rows, cols = param_value.shape[-2:]
-
-            # Check if dimensions are divisible by block sizes
-            if rows % block_size_m != 0 or cols % block_size_n != 0:
-                raise ValueError(
-                    f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
-                )
-
-
-            # Create blocks using unfold
-            param_value = param_value.unfold(-2, block_size_m, block_size_m)
-            param_value = param_value.unfold(-2, block_size_n, block_size_n)
-
-            # Calculate scaling factor for each block
-            max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
-            scale = fp8_max / max_abs
-
-            # Expand scale to match block dimensions for multiplication
-            scale = scale.unsqueeze(-1).unsqueeze(-1)
-
-            # Quantize the weights
-            quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-            # Reshape back to matrix shape
-            quantized_param = quantized_param.reshape(param_value.shape[:-4] + (rows, cols))
-
-            # Reshape scale to match the number of blocks
-            scale = scale.reshape(scale.shape[:-2] + (-1,)).squeeze(-1).reciprocal()
-
-        else:
-            # Per-tensor quantization
-            max_abs = torch.max(torch.abs(param_value))
-            # print("###################max_abs#################", max_abs)
-            scale = fp8_max / max_abs
-            # print("###################scale#################", scale)
-            # Quantize the weights
-            quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-            # For per-tensor quantization, we just need a single scale value
-            scale = torch.tensor([[scale]]).reciprocal()
-            # print("###################after reciprocal scale#################", scale)
+        if self.quantization_config.weight_block_size is None:
+            self.quantization_config.weight_block_size = (128, 128)
+        block_size_m, block_size_n = self.quantization_config.weight_block_size
+
+        rows, cols = param_value.shape[-2:]
+
+        # Check if dimensions are divisible by block sizes
+        if rows % block_size_m != 0 or cols % block_size_n != 0:
+            raise ValueError(
+                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
+            )
+
+        param_value_orig_shape = param_value.shape
+
+        # Split the matrix into (block_size_m, block_size_n) blocks
+        param_value = param_value.reshape(
+            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
+        ).permute(0, 1, 3, 2, 4)
+
+        # Calculate the scaling factor for each block
+        max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+        scale = fp8_max / max_abs
+        scale_orig_shape = scale.shape
+        # Expand scale to match block dimensions for multiplication
+        scale = scale.unsqueeze(-1).unsqueeze(-1)
+
+        # Quantize the weights
+        quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        # Undo the blocking and restore the original matrix shape
+        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
+        quantized_param = quantized_param.reshape(param_value_orig_shape)
+
+        # Reshape scale to one value per block, inverted for dequantization
+        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
 
         # Store the quantized weights and scales in the module
         module._parameters[tensor_name] = quantized_param.to(target_device)
         module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
@@ -173,6 +162,22 @@ class FP8HfQuantizer(HfQuantizer):
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
         return max_memory
+
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        from ..integrations import FP8Linear
+
+        not_missing_keys = []
+        for name, module in model.named_modules():
+            if isinstance(module, FP8Linear):
+                for missing in missing_keys:
+                    if (
+                        (name in missing or name in f"{prefix}.{missing}")
+                        and not missing.endswith(".weight")
+                        and not missing.endswith(".bias")
+                    ):
+                        not_missing_keys.append(missing)
+        return [k for k in missing_keys if k not in not_missing_keys]
+
     def is_serializable(self, safe_serialization=None):
         return True
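
Reviewer note: the first hunk quantizes each (block_size_m, block_size_n) tile
against the FP8 e4m3 range and stores the reciprocal scales as
"weight_scale_inv". Below is a self-contained sketch of that round trip. The
helper name quantize_fp8_blockwise and the toy shapes are illustrative, not
part of the patch; it assumes a PyTorch build with float8_e4m3fn support
(2.1 or newer).

    import torch

    def quantize_fp8_blockwise(weight, block_size=(128, 128)):
        """Quantize a 2D matrix to float8_e4m3fn with one scale per tile."""
        fp8_min = torch.finfo(torch.float8_e4m3fn).min
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        block_m, block_n = block_size
        rows, cols = weight.shape[-2:]
        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError("dimensions must be divisible by the block sizes")

        orig_shape = weight.shape
        # Tile the matrix: (-1, row_blocks, col_blocks, block_m, block_n)
        tiles = weight.reshape(
            -1, rows // block_m, block_m, cols // block_n, block_n
        ).permute(0, 1, 3, 2, 4)

        # One scale per tile: map the tile's max |value| onto the FP8 max
        scale = fp8_max / tiles.abs().amax(dim=(-1, -2), keepdim=True)
        quantized = torch.clamp(tiles * scale, fp8_min, fp8_max).to(torch.float8_e4m3fn)

        # Undo the tiling so the quantized tensor keeps the original layout
        quantized = quantized.permute(0, 1, 3, 2, 4).reshape(orig_shape)
        # Return the inverse scale: dequantization is then a multiply
        return quantized, scale.squeeze().reciprocal()

    w = torch.randn(256, 512)
    q, s_inv = quantize_fp8_blockwise(w)
    print(q.dtype, q.shape, s_inv.shape)
    # torch.float8_e4m3fn torch.Size([256, 512]) torch.Size([2, 4])
    # Dequantizing the first tile recovers w up to the quantization error:
    err = (q[:128, :128].float() * s_inv[0, 0] - w[:128, :128]).abs().max()
    print(err)

Storing the scale already inverted is why the parameter is named
weight_scale_inv: the matmul kernels dequantize with a multiply instead of a
divide.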
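
The update_missing_keys override in the second hunk is needed because on-the-fly
quantization creates the per-block scales itself: when a bf16/fp32 checkpoint is
loaded, weight_scale_inv is absent from the state dict and would otherwise be
flagged as missing, even though the first hunk registers it during quantization.
A minimal mimic of the filtering rule, with a hypothetical FakeFP8Linear
standing in for transformers' FP8Linear:

    from typing import List

    import torch.nn as nn

    class FakeFP8Linear(nn.Linear):
        # Hypothetical stand-in for transformers' FP8Linear, demo only
        pass

    def filter_missing_keys(model, missing_keys: List[str], prefix: str) -> List[str]:
        # Same rule as the patch: inside FP8 layers, only .weight/.bias may
        # stay on the report; scale tensors are created by the quantizer
        not_missing = []
        for name, module in model.named_modules():
            if isinstance(module, FakeFP8Linear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing.append(missing)
        return [k for k in missing_keys if k not in not_missing]

    model = nn.Sequential(FakeFP8Linear(4, 4))
    print(filter_missing_keys(model, ["0.weight", "0.weight_scale_inv"], "model"))
    # ['0.weight'] -- the scale key is no longer reported as missing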