mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix quantization logic
parent 70749dfd9b
commit 83912834cc
1 changed file with 49 additions and 44 deletions
@@ -72,52 +72,41 @@ class FP8HfQuantizer(HfQuantizer):
# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max

if self.quantization_config.weight_block_size is not None:

if self.quantization_config.weight_block_size is None:
    self.quantization_config.weight_block_size = (128, 128)
else:
    block_size_m, block_size_n = self.quantization_config.weight_block_size

rows, cols = param_value.shape[-2:]

# Check if dimensions are divisible by block sizes
if rows % block_size_m != 0 or cols % block_size_n != 0:
    raise ValueError(
        f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
    )

param_value_orig_shape = param_value.shape

# Create blocks using reshape
param_value = param_value.reshape(
    -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
).permute(0, 1, 3, 2, 4)

# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
# Expand scale to match block dimensions for multiplication
scale = scale.unsqueeze(-1).unsqueeze(-1)

# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)

# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()

rows, cols = param_value.shape[-2:]

# Check if dimensions are divisible by block sizes
if rows % block_size_m != 0 or cols % block_size_n != 0:
    raise ValueError(
        f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
    )

# Create blocks using unfold
param_value = param_value.unfold(-2, block_size_m, block_size_m)
param_value = param_value.unfold(-2, block_size_n, block_size_n)

# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs

# Expand scale to match block dimensions for multiplication
scale = scale.unsqueeze(-1).unsqueeze(-1)

# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value.shape[:-4] + (rows, cols))

# Reshape scale to match the number of blocks
scale = scale.reshape(scale.shape[:-2] + (-1,)).squeeze(-1).reciprocal()

else:
    # Per-tensor quantization
    max_abs = torch.max(torch.abs(param_value))
    # print("###################max_abs#################", max_abs)
    scale = fp8_max / max_abs
    # print("###################scale#################", scale)
    # Quantize the weights
    quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    # For per-tensor quantization, we just need a single scale value
    scale = torch.tensor([[scale]]).reciprocal()
    # print("###################after reciprocal scale#################", scale)

# Store the quantized weights and scales in the module
module._parameters[tensor_name] = quantized_param.to(target_device)
module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
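For reference, a minimal, self-contained sketch of the block-wise scheme used in the reshape/permute variant of the hunk above: split the weight matrix into (block_m, block_n) tiles, compute one scale per tile from its max absolute value, clamp and cast to float8_e4m3fn, and keep the inverse scale (what the module stores as weight_scale_inv). This is an illustration only, not the commit's code: the helper names, the 256x512 random weight, the dequantize step, and the per-tensor comparison at the end are assumptions, and it assumes a PyTorch build that supports torch.float8_e4m3fn.

import torch

def quantize_block_fp8(weight, block_size=(128, 128)):
    # One scale per (block_m, block_n) tile, mirroring the reshape/permute scheme above.
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    block_m, block_n = block_size
    rows, cols = weight.shape[-2:]
    if rows % block_m != 0 or cols % block_n != 0:
        raise ValueError("matrix dimensions must be divisible by the block sizes")

    orig_shape = weight.shape
    # (rows, cols) -> (-1, rows//bm, bm, cols//bn, bn) -> (-1, rows//bm, cols//bn, bm, bn)
    blocks = weight.reshape(-1, rows // block_m, block_m, cols // block_n, block_n).permute(0, 1, 3, 2, 4)

    max_abs = torch.amax(torch.abs(blocks), dim=(-1, -2))
    scale = fp8_max / max_abs  # per-block scale
    q = torch.clamp(blocks * scale.unsqueeze(-1).unsqueeze(-1), min=fp8_min, max=fp8_max)
    q = q.to(torch.float8_e4m3fn).permute(0, 1, 3, 2, 4).reshape(orig_shape)

    # The stored tensor is the *inverse* scale, analogous to weight_scale_inv in the diff.
    return q, scale.squeeze().reciprocal()

def dequantize_block_fp8(q, scale_inv, block_size=(128, 128)):
    # Undo the quantization: regroup into blocks and multiply by the stored inverse scale.
    block_m, block_n = block_size
    rows, cols = q.shape[-2:]
    blocks = q.to(torch.float32).reshape(-1, rows // block_m, block_m, cols // block_n, block_n).permute(0, 1, 3, 2, 4)
    blocks = blocks * scale_inv.reshape(-1, rows // block_m, cols // block_n, 1, 1)
    return blocks.permute(0, 1, 3, 2, 4).reshape(q.shape)

w = torch.randn(256, 512)
q, s_inv = quantize_block_fp8(w)
w_hat = dequantize_block_fp8(q, s_inv)
print("block-wise max abs error:", (w - w_hat).abs().max().item())

# Per-tensor fallback for comparison: a single scale for the whole matrix.
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
s = fp8_max / w.abs().max()
w_pt = torch.clamp(w * s, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).to(torch.float32) / s
print("per-tensor max abs error:", (w - w_pt).abs().max().item())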
@@ -173,6 +162,22 @@ class FP8HfQuantizer(HfQuantizer):
    max_memory = {key: val * 0.90 for key, val in max_memory.items()}
    return max_memory

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
    from ..integrations import FP8Linear

    not_missing_keys = []
    for name, module in model.named_modules():
        if isinstance(module, FP8Linear):
            for missing in missing_keys:
                if (
                    (name in missing or name in f"{prefix}.{missing}")
                    and not missing.endswith(".weight")
                    and not missing.endswith(".bias")
                ):
                    not_missing_keys.append(missing)
    return [k for k in missing_keys if k not in not_missing_keys]

def is_serializable(self, safe_serialization=None):
    return True
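And a small sketch of what the new update_missing_keys hook accomplishes: checkpoint keys that point at an FP8 module but are neither its .weight nor its .bias (for example the quantizer-created weight_scale_inv) are dropped from the missing-keys report, while genuinely missing weights stay. FakeFP8Linear, the toy model, and the free-standing filter function below are illustrative stand-ins, not the real FP8Linear integration or quantizer API.

from typing import List

import torch.nn as nn

class FakeFP8Linear(nn.Linear):
    # Stand-in for transformers' FP8Linear, only used to mark FP8-quantized modules here.
    pass

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = FakeFP8Linear(4, 4)
        self.lm_head = nn.Linear(4, 4)

def filter_missing_keys(model: nn.Module, missing_keys: List[str], prefix: str) -> List[str]:
    # Same filtering rule as update_missing_keys in the diff above.
    not_missing = []
    for name, module in model.named_modules():
        if isinstance(module, FakeFP8Linear):
            for missing in missing_keys:
                if (
                    (name in missing or name in f"{prefix}.{missing}")
                    and not missing.endswith(".weight")
                    and not missing.endswith(".bias")
                ):
                    not_missing.append(missing)
    return [k for k in missing_keys if k not in not_missing]

model = ToyModel()
missing = ["proj.weight", "proj.weight_scale_inv", "lm_head.weight"]
print(filter_missing_keys(model, missing, prefix="model"))
# -> ['proj.weight', 'lm_head.weight']  (the scale tensor is no longer flagged as missing)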