From 45dc9e42936efc5d51ad30ebb2d50d6e3cc10bef Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 4 Feb 2025 09:58:42 +0000 Subject: [PATCH] first commit --- docs/source/en/_toctree.yml | 2 + docs/source/en/main_classes/quantization.md | 4 + docs/source/en/quantization/fp8.md | 0 src/transformers/__init__.py | 2 + src/transformers/integrations/__init__.py | 1 + src/transformers/integrations/fp8.py | 452 ++++++++++++++++++ src/transformers/quantizers/auto.py | 5 +- src/transformers/quantizers/quantizer_fp8.py | 180 +++++++ src/transformers/utils/quantization_config.py | 24 +- 9 files changed, 668 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/quantization/fp8.md create mode 100644 src/transformers/integrations/fp8.py create mode 100644 src/transformers/quantizers/quantizer_fp8.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2a2cf4512..c361ef49d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -185,6 +185,8 @@ title: BitNet - local: quantization/compressed_tensors title: compressed-tensors + - local: quantization/fp8 + title: FP8 - local: quantization/contribute title: Contribute new quantization method title: Quantization Methods diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 037660d06..b7a0e4cf7 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -80,3 +80,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. ## BitNetConfig [[autodoc]] BitNetConfig + +## FP8Config + +[[autodoc]] FP8Config diff --git a/docs/source/en/quantization/fp8.md b/docs/source/en/quantization/fp8.md new file mode 100644 index 000000000..e69de29bb diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ae92f21dc..2a26c95ad 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1020,6 +1020,7 @@ _import_structure = { "CompressedTensorsConfig", "EetqConfig", "FbgemmFp8Config", + "FP8Config", "GPTQConfig", "HiggsConfig", "HqqConfig", @@ -6159,6 +6160,7 @@ if TYPE_CHECKING: CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, + FP8Config, GPTQConfig, HiggsConfig, HqqConfig, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 49dbc5e3a..f76fbe9b1 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -54,6 +54,7 @@ _import_structure = { ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], + "fp8": ["FP8Linear", "FP8MoELinear", "replace_with_fp8_linear"], "fsdp": ["is_fsdp_managed_module"], "ggml": [ "GGUF_CONFIG_MAPPING", diff --git a/src/transformers/integrations/fp8.py b/src/transformers/integrations/fp8.py new file mode 100644 index 000000000..a6e15869e --- /dev/null +++ b/src/transformers/integrations/fp8.py @@ -0,0 +1,452 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from typing import Optional, List, Tuple, Union +from ..utils import is_accelerate_available, logging + +if is_accelerate_available(): + from accelerate import init_empty_weights + +logger = logging.get_logger(__name__) + +ACTIVATION_SCHEMES = ["static", "dynamic"] +quant_dtype = torch.float8_e4m3fn + +def fp8_quantize(weight: torch.Tensor, scale: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize weights to FP8.""" + if scale is None: + # Calculate scale as max value divided by absmax + scale = 448.0 / weight.abs().max().clamp(min=1e-12) + # Scale and clamp tensor to FP8 range + qweight = (weight * scale).clamp(min=-448.0, max=448.0) + scale = scale.float().reciprocal() + else: + qweight = (weight * scale.reciprocal()).clamp(min=-448.0, max=448.0) + + qweight = qweight.to(quant_dtype) + return qweight, scale + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-12, + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Performs per-token-group quantization on input tensor, converting to FP8. + + Args: + x (torch.Tensor): Input tensor to quantize (shape: [..., hidden_dim]) + group_size (int): Size of groups for quantization + column_major_scales (bool): If True, returns scales in column-major format + eps (float): Small value to avoid division by zero + dtype (torch.dtype, optional): FP8 dtype to use. Defaults to platform-specific. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors + """ + # Input validation + assert x.ndim >= 2, "Input tensor must have at least 2 dimensions" + assert x.shape[-1] % group_size == 0, f"Last dimension ({x.shape[-1]}) must be divisible by group_size ({group_size})" + + # Determine FP8 dtype and limits + if dtype is None: + dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn + finfo = torch.finfo(dtype) + + # Reshape input for group-wise operations + orig_shape = x.shape + num_groups = x.shape[-1] // group_size + + # Reshape to [*, num_groups, group_size] + x_reshaped = x.view(-1, num_groups, group_size) + + # Calculate max absolute values per group + max_abs = x_reshaped.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps) + + # Calculate scales as max_dtype / max_abs + scales = finfo.max / max_abs + + # Quantize values + x_scaled = (x_reshaped * scales) + x_quant = x_scaled.clamp(min=finfo.min, max=finfo.max).to(dtype) + + # Reshape back to original shape + x_quant = x_quant.view(orig_shape) + + # Process scales + scales = scales.squeeze(-1) # Remove the last singleton dimension + + scales = scales.view(-1, num_groups) + + # Return reciprocal of scales for compatibility with other operations + return x_quant, scales.float().reciprocal() + +def per_token_group_dequant_fp8( + x: torch.Tensor, + scales: torch.Tensor, + group_size: int, + output_dtype: torch.dtype = torch.float16 +) -> torch.Tensor: + """ + Dequantizes FP8 tensor back to floating point using group scales. 
+ + Args: + x (torch.Tensor): Quantized input tensor + scales (torch.Tensor): Scale factors (reciprocal) + group_size (int): Size of groups used in quantization + output_dtype (torch.dtype): Output dtype + + Returns: + torch.Tensor: Dequantized tensor + """ + # Reshape input for group-wise operations + orig_shape = x.shape + num_groups = x.shape[-1] // group_size + x_reshaped = x.view(-1, num_groups, group_size) + + # Ensure scales have correct shape + if scales.ndim == 2: + scales = scales.view(-1, num_groups, 1) + else: + scales = scales.view(*orig_shape[:-1], num_groups, 1) + + # Dequantize + x_dequant = x_reshaped.to(torch.float32) * scales + + # Reshape back and convert to desired dtype + return x_dequant.view(orig_shape).to(output_dtype) + +@torch.compile +def w8a8_block_fp8_matmul( + input_q: torch.Tensor, # [batch, seq_len, hidden_dim] + weight_q: torch.Tensor, # [out_features, hidden_dim] + input_scale: torch.Tensor, # [batch * seq_len, num_input_groups] + weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n] + block_size: Tuple[int, int], # (M=128, N=128) for weights + output_dtype: torch.dtype = torch.float16 +) -> torch.Tensor: + """ + Performs blocked matrix multiplication with FP8 quantized matrices. + + Args: + input_q: Quantized input tensor with 1x128 block quantization + weight_q: Quantized weight tensor with 128x128 block quantization + input_scale: Scaling factors for input blocks + weight_scale: Scaling factors for weight blocks + block_size: Tuple of (M, N) for weight block dimensions + output_dtype: Desired output dtype + """ + batch_size, seq_len, hidden_dim = input_q.shape + out_features = weight_q.shape[0] + + # Reshape input for batched matmul + input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim] + + # Calculate number of blocks + num_weight_blocks_m = out_features // block_size[0] + num_weight_blocks_n = hidden_dim // block_size[1] + + # Initialize output tensor + output = torch.zeros((batch_size * seq_len, out_features), + dtype=torch.float32, + device=input_q.device) + + # Process each block + for i in range(num_weight_blocks_m): + m_start = i * block_size[0] + m_end = m_start + block_size[0] + + for j in range(num_weight_blocks_n): + n_start = j * block_size[1] + n_end = n_start + block_size[1] + + # Extract current blocks + input_block = input_reshaped[:, n_start:n_end] + weight_block = weight_q[m_start:m_end, n_start:n_end] + + # Get corresponding scales + curr_input_scale = input_scale[:, j:j+1] # [batch*seq_len, 1] + curr_weight_scale = weight_scale[i, j] # scalar + + # Dequantize and multiply + block_result = torch._scaled_mm( + input_block, + weight_block.t(), + scale_a=curr_input_scale, + scale_b=curr_weight_scale, + out_dtype=x.dtype + ) + # block_result = torch.matmul( + # input_block.to(torch.float32) * curr_input_scale, + # weight_block.to(torch.float32).t() * curr_weight_scale + # ) + + # Accumulate result + output[:, m_start:m_end] += block_result + + # Reshape output back to original dimensions + output = output.view(batch_size, seq_len, out_features) + + return output.to(output_dtype) + + +def fp8_quantize(weight, scale: Optional[torch.Tensor] = None, qdtype=torch.float8_e4m3fn): + if scale is None: + # weight, scale = quant_weights(weight, torch.int8, False) + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as 
+        # default cast is unsaturated)
+        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        scale = scale.float().reciprocal()
+    else:
+        finfo = torch.finfo(qdtype)
+        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
+    # Return both the float8 data and the inverse scale (as float),
+    # as both are required as inputs to torch._scaled_mm
+    qweight = qweight.to(qdtype)
+    return qweight, scale
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Convert e4m3fn weights and scales to e4m3fnuz format for ROCm compatibility."""
+    if weight.dtype != torch.float8_e4m3fn:
+        return weight, weight_scale, input_scale
+
+    # Convert -128 (NaN in e4m3fnuz) to 0
+    weight_as_int8 = weight.view(torch.int8)
+    weight_as_int8[weight_as_int8 == -128] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # Double the scales since e4m3fnuz values are half of e4m3fn
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
+
+
+class FP8Linear(nn.Module):
+    """FP8 Linear layer implementation."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        activation_scheme="dynamic",
+        weight_block_size=None
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.activation_scheme = activation_scheme
+        self.weight_block_size = weight_block_size
+
+        self.weight = nn.Parameter(torch.empty((out_features, in_features), dtype=quant_dtype, device=device))
+        self.weight_scale = nn.Parameter(torch.empty(1, dtype=torch.float32, device=device))
+
+        if activation_scheme == "static":
+            self.input_scale = nn.Parameter(torch.empty(1, dtype=torch.float32, device=device))
+        else:
+            self.register_parameter('input_scale', None)
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Only the dynamic activation scheme is handled here: activations are quantized
+        # on the fly with per-token groups of size weight_block_size[1]
+        if self.activation_scheme == "dynamic":
+            qinput, input_scale = per_token_group_quant_fp8(x, self.weight_block_size[1])
+
+        # Handle ROCm compatibility (e4m3fn -> e4m3fnuz)
+        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            self.weight, self.weight_scale, input_scale
+        )
+
+        # Standard blocked FP8 matmul
+        output = w8a8_block_fp8_matmul(
+            qinput,
+            weight,
+            input_scale,
+            weight_scale,
+            self.weight_block_size,
+            output_dtype=x.dtype,
+        )
+
+        if self.bias is not None:
+            output = output + self.bias
+
+        return output
+
+class FP8MoELinear(FP8Linear):
+    """FP8 Linear layer for MoE implementation."""
+
+    def __init__(
+        self,
+        n_experts: int,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        activation_scheme="dynamic",
+        weight_block_size=None
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            device,
+            dtype,
+            activation_scheme,
+            weight_block_size
+        )
+        self.n_experts = n_experts
+
+        # Reshape weight and scale to hold one slice per expert
+        self.weight = nn.Parameter(
+            torch.empty((n_experts, out_features, in_features), dtype=quant_dtype, device=device)
+        )
+        self.weight_scale = nn.Parameter(
+            torch.empty((n_experts, 1), dtype=torch.float32, device=device)
+        )
+
+    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
+        # Handle ROCm compatibility
+        weight, weight_scale, input_scale =
normalize_e4m3fn_to_e4m3fnuz( + self.weight, self.weight_scale, self.input_scale + ) + + if self.activation_scheme == "dynamic": + input_scale = x.abs().max() / torch.finfo(quant_dtype).max + + # Select expert weights and scales + selected_weights = weight[expert_indices] + selected_scales = weight_scale[expert_indices] + + # Perform FP8 matmul for each expert + output = torch._scaled_mm( + x, + selected_weights.transpose(-1, -2), + scale_a=input_scale, + scale_b=selected_scales, + out_dtype=x.dtype + ) + + if self.bias is not None: + output = output + self.bias[expert_indices] + + return output + +def _replace_with_fp8_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """Replace Linear layers with FP8Linear.""" + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []): + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + with init_empty_weights(): + # Check if this is an MoE layer + is_moe = any(moe_key in current_key_name_str + for moe_key in ["gate", "experts"]) + is_moe = False + + if is_moe: + n_experts = getattr(model.config, "num_experts", 8) + model._modules[name] = FP8MoELinear( + n_experts=n_experts, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype, + activation_scheme=quantization_config.activation_scheme, + weight_block_size=quantization_config.weight_block_size + ) + else: + model._modules[name] = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype, + activation_scheme=quantization_config.activation_scheme, + weight_block_size=quantization_config.weight_block_size + ) + has_been_replaced = True + + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_fp8_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + + current_key_name.pop(-1) + + return model, has_been_replaced + +def replace_with_fp8_linear( + model, + modules_to_not_convert=None, + quantization_config=None, +): + """Helper function to replace model layers with FP8 versions.""" + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + + model, has_been_replaced = _replace_with_fp8_linear( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model using fp8 but no linear modules were found in your model." + " Please double check your model architecture." 
+ ) + + return model \ No newline at end of file diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index d5b51d038..dbca8c208 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -31,6 +31,7 @@ from ..utils.quantization_config import ( QuantoConfig, TorchAoConfig, VptqConfig, + FP8Config, ) from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_awq import AwqQuantizer @@ -46,7 +47,7 @@ from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer from .quantizer_vptq import VptqHfQuantizer - +from .quantizer_fp8 import FP8HfQuantizer AUTO_QUANTIZER_MAPPING = { "awq": AwqQuantizer, @@ -63,6 +64,7 @@ AUTO_QUANTIZER_MAPPING = { "torchao": TorchAoHfQuantizer, "bitnet": BitNetHfQuantizer, "vptq": VptqHfQuantizer, + "fp8": FP8HfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -80,6 +82,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = { "torchao": TorchAoConfig, "bitnet": BitNetConfig, "vptq": VptqConfig, + "fp8": FP8Config, } diff --git a/src/transformers/quantizers/quantizer_fp8.py b/src/transformers/quantizers/quantizer_fp8.py new file mode 100644 index 000000000..137844637 --- /dev/null +++ b/src/transformers/quantizers/quantizer_fp8.py @@ -0,0 +1,180 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import torch +from torch.nn.parameter import Parameter as Parameter +from .base import HfQuantizer +from ..utils import is_accelerate_available, logging +from .quantizers_utils import get_module_from_name + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +logger = logging.get_logger(__name__) + +class FP8HfQuantizer(HfQuantizer): + """ + FP8 quantization implementation supporting both standard and MoE models. + Supports both e4m3fn and e4m3fnuz formats based on platform. + """ + + requires_parameters_quantization = False + requires_calibration = False + required_packages = ["accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + self.is_moe_model = kwargs.get("is_moe_model", False) + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available(): + raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)") + + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): + raise ValueError( + "Converting into FP8 weights from tf/flax weights is currently not supported, " + "please make sure the weights are in PyTorch format." + ) + + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.") + + device_map = kwargs.get("device_map", None) + if device_map is None: + logger.warning_once( + "You have loaded an FP8 model on CPU and have a CUDA device available. " + "Make sure to set your model on a GPU device to run your model." + ) + elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): + raise ValueError( + "FP8 models do not support CPU or disk offloading in the device map. " + "Please remove CPU/disk devices from the device map." 
+ ) + + def update_torch_dtype(self, torch_dtype): + torch_dtype = torch.float32 + return torch_dtype + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + Quantizes weights to FP8 format using either: + - Block-wise quantization when weight_block_size is provided + - Per-tensor quantization when weight_block_size is None + """ + module, tensor_name = get_module_from_name(model, param_name) + + # Get FP8 min/max values + fp8_min = torch.finfo(torch.float8_e4m3fn).min + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + if self.quantization_config.weight_block_size is not None: + # Block-wise quantization + block_size_m, block_size_n = self.quantization_config.weight_block_size + + # Get matrix dimensions + rows, cols = param_value.shape[-2:] + + # Check if dimensions are divisible by block sizes + if rows % block_size_m != 0 or cols % block_size_n != 0: + raise ValueError( + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})" + ) + + # Create blocks using unfold + param_value = param_value.unfold(-2, block_size_m, block_size_m) + param_value = param_value.unfold(-2, block_size_n, block_size_n) + + # Calculate scaling factor for each block + max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) + scale = fp8_max / max_abs + + # Expand scale to match block dimensions for multiplication + scale = scale.unsqueeze(-1).unsqueeze(-1) + + # Quantize the weights + quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max) + + # Reshape back to matrix shape + quantized_param = quantized_param.reshape(param_value.shape[:-4] + (rows, cols)) + + # Reshape scale to match the number of blocks + scale = scale.reshape(scale.shape[:-2] + (-1,)) + + else: + # Per-tensor quantization + max_abs = torch.max(torch.abs(param_value)) + scale = fp8_max / max_abs + + # Quantize the weights + quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max) + + # For per-tensor quantization, we just need a single scale value + scale = torch.tensor([scale]) + # Store the quantized weights and scales in the module + print(f"tensor_name in create_quantized_param: {tensor_name} {target_device}") + module._buffers[tensor_name] = quantized_param.to(target_device) + module._buffers["weight_scale"] = scale.to(target_device) + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + from ..integrations.fp8 import FP8Linear, FP8MoELinear + + module, tensor_name = get_module_from_name(model, param_name) + + if isinstance(module, (FP8Linear, FP8MoELinear)): + if tensor_name == "bias": + return False + if tensor_name == "weight": + return True + if tensor_name in ["weight_scale", "input_scale"]: + return False + return False + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from ..integrations.fp8 import replace_with_fp8_linear + + self.modules_to_not_convert = ["lm_head"] + keep_in_fp32_modules + + if self.quantization_config.modules_to_not_convert: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + model = replace_with_fp8_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + 
quantization_config=self.quantization_config, + ) + + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def is_serializable(self, safe_serialization=None): + return True + + @property + def is_trainable(self) -> bool: + return False # FP8 quantization is typically used for inference only \ No newline at end of file diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index b356b3d9b..2529b6ef5 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -21,7 +21,7 @@ import os from dataclasses import dataclass from enum import Enum from inspect import Parameter, signature -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple from packaging import version @@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum): FBGEMM_FP8 = "fbgemm_fp8" TORCHAO = "torchao" BITNET = "bitnet" + FP8 = "fp8" class AWQLinearVersion(str, Enum): @@ -1548,3 +1549,24 @@ class BitNetConfig(QuantizationConfigMixin): Safety checker that arguments are correct """ pass + +@dataclass +class FP8Config(QuantizationConfigMixin): + def __init__( + self, + modules_to_not_convert: Optional[List] = None, + activation_scheme: Optional[str] = "dynamic", + weight_block_size: Optional[Tuple[int, int]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.FP8 + self.modules_to_not_convert = modules_to_not_convert + self.activation_scheme = activation_scheme + self.weight_block_size = weight_block_size + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + pass
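The per-tensor path in `fp8.py` follows the usual `torch._scaled_mm` convention: the quantization multiplier maps the tensor's absolute maximum onto the FP8 maximum (448 for e4m3fn), and the reciprocal of that multiplier is what gets stored and later used for dequantization. A minimal standalone sketch of the round trip, assuming PyTorch 2.1+ for `torch.float8_e4m3fn` (the helper name is illustrative, not part of the patch):

```python
import torch

def fp8_per_tensor_round_trip(weight: torch.Tensor):
    """Standalone sketch of the per-tensor scheme used by `fp8_quantize`."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Multiplier that maps the largest magnitude onto the FP8 max (448 for e4m3fn)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Return the inverse scale, i.e. the dequantization factor expected by torch._scaled_mm
    return qweight, scale.float().reciprocal()

w = torch.randn(256, 256)
qw, inv_scale = fp8_per_tensor_round_trip(w)
w_hat = qw.to(torch.float32) * inv_scale
print((w - w_hat).abs().max())  # small quantization error relative to the weight magnitudes
```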
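`per_token_group_quant_fp8` applies the same idea per token and per group of `group_size` channels, returning one dequantization scale per (token, group) pair. A simplified sketch of the shapes involved, under the same assumptions (names are illustrative):

```python
import torch

def group_quant_fp8(x: torch.Tensor, group_size: int):
    """Standalone sketch mirroring `per_token_group_quant_fp8`."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    num_groups = x.shape[-1] // group_size
    x_grouped = x.reshape(-1, num_groups, group_size)
    max_abs = x_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = finfo.max / max_abs                                  # quantization multiplier per group
    x_q = (x_grouped * scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q.reshape(x.shape), scale.squeeze(-1).reciprocal()  # FP8 data + dequant scales

x = torch.randn(2, 16, 512)                 # [batch, seq_len, hidden_dim]
x_q, scales = group_quant_fp8(x, group_size=128)
print(x_q.shape, x_q.dtype)                 # torch.Size([2, 16, 512]) torch.float8_e4m3fn
print(scales.shape)                         # torch.Size([32, 4]): one scale per (token, group)
```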
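On the weight side, `create_quantized_param` quantizes 128x128 blocks with one scale per block. A standalone sketch of that block-wise scheme using `unfold` (illustrative names; the permute/reshape at the end undoes the block layout):

```python
import torch

def block_quantize_fp8(weight: torch.Tensor, block_size=(128, 128)):
    """Standalone sketch of 128x128 block-wise weight quantization with one scale per block."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    bm, bn = block_size
    rows, cols = weight.shape
    blocks = weight.unfold(0, bm, bm).unfold(1, bn, bn)          # [rows//bm, cols//bn, bm, bn]
    scale = finfo.max / blocks.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=1e-12)
    q_blocks = (blocks * scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Undo the blocking: [rows//bm, cols//bn, bm, bn] -> [rows//bm, bm, cols//bn, bn] -> [rows, cols]
    q_weight = q_blocks.permute(0, 2, 1, 3).reshape(rows, cols)
    return q_weight, scale.squeeze(-1).squeeze(-1).reciprocal()  # FP8 weight + per-block dequant scales

w = torch.randn(256, 512)
qw, scales = block_quantize_fp8(w)
print(qw.shape, scales.shape)  # torch.Size([256, 512]) torch.Size([2, 4])
```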
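`w8a8_block_fp8_matmul` combines the two scale layouts through `torch._scaled_mm`, which generally requires a GPU with native FP8 support. For numerical checks on any device, an equivalent float32 reference that simply dequantizes and multiplies can be useful; this is a sketch with illustrative names, not part of the patch:

```python
import torch

def blocked_dequant_matmul_reference(
    input_q: torch.Tensor,      # FP8 activations, [..., hidden_dim]
    input_scale: torch.Tensor,  # per-(token, group) dequant scales, [num_tokens, hidden_dim // bn]
    weight_q: torch.Tensor,     # FP8 weights, [out_features, hidden_dim]
    weight_scale: torch.Tensor, # per-block dequant scales, [out_features // bm, hidden_dim // bn]
    block_size=(128, 128),
) -> torch.Tensor:
    """Float32 reference: dequantize both operands, then do a plain matmul."""
    bm, bn = block_size
    out_features, hidden_dim = weight_q.shape
    # Expand block scales back to full weight resolution
    w_scale_full = weight_scale.repeat_interleave(bm, dim=0).repeat_interleave(bn, dim=1)
    w = weight_q.to(torch.float32) * w_scale_full
    # Expand per-group activation scales back to full hidden resolution
    tokens = input_q.reshape(-1, hidden_dim)
    a = tokens.to(torch.float32) * input_scale.repeat_interleave(bn, dim=1)
    return (a @ w.t()).reshape(*input_q.shape[:-1], out_features)

# Smoke test with unit scales (FP8 values are used as-is)
inp = torch.randn(2, 4, 256).clamp(-2, 2).to(torch.float8_e4m3fn)
wt = torch.randn(128, 256).clamp(-2, 2).to(torch.float8_e4m3fn)
out = blocked_dequant_matmul_reference(inp, torch.ones(8, 2), wt, torch.ones(1, 2))
print(out.shape)  # torch.Size([2, 4, 128])
```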
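`replace_with_fp8_linear` decides which `nn.Linear` modules to convert by substring matching against the dotted module path, so an entry like `"lm_head"` skips every module whose qualified name contains it. A small illustration of that matching rule on a toy module tree:

```python
import torch.nn as nn

# Toy module tree to illustrate which Linear layers the replacement pass converts.
model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(100, 64),
    "layers": nn.ModuleList(
        [nn.ModuleDict({"q_proj": nn.Linear(64, 64), "mlp": nn.Linear(64, 256)})]
    ),
    "lm_head": nn.Linear(64, 100),
})

modules_to_not_convert = ["lm_head"]
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        skipped = any(key in name for key in modules_to_not_convert)
        print(f"{name}: {'kept in high precision' if skipped else 'would become FP8Linear'}")
```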
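Assuming this branch is installed from source and a GPU with FP8 support is available, end-to-end usage would look roughly like the following; the model id is a placeholder and the exact defaults may change as the PR evolves:

```python
from transformers import AutoModelForCausalLM, FP8Config

# Placeholder model id; any decoder-only checkpoint with Linear layers is the intended target.
quantization_config = FP8Config(
    activation_scheme="dynamic",         # activations quantized on the fly, per token group
    weight_block_size=(128, 128),        # 128x128 block-wise weight quantization
    modules_to_not_convert=["lm_head"],  # keep the output head in higher precision
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    device_map="auto",
    quantization_config=quantization_config,
)
print(model)  # Linear layers (except lm_head) should now be FP8Linear
```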