Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00

Commit message: first commit
Parent: 7eecdf2a86
Commit: 45dc9e4293
9 changed files with 668 additions and 2 deletions
@@ -185,6 +185,8 @@
        title: BitNet
      - local: quantization/compressed_tensors
        title: compressed-tensors
      - local: quantization/fp8
        title: FP8
      - local: quantization/contribute
        title: Contribute new quantization method
    title: Quantization Methods
@@ -80,3 +80,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

## BitNetConfig

[[autodoc]] BitNetConfig

## FP8Config

[[autodoc]] FP8Config
0 docs/source/en/quantization/fp8.md Normal file
@@ -1020,6 +1020,7 @@ _import_structure = {
        "CompressedTensorsConfig",
        "EetqConfig",
        "FbgemmFp8Config",
        "FP8Config",
        "GPTQConfig",
        "HiggsConfig",
        "HqqConfig",
@@ -6159,6 +6160,7 @@ if TYPE_CHECKING:
        CompressedTensorsConfig,
        EetqConfig,
        FbgemmFp8Config,
        FP8Config,
        GPTQConfig,
        HiggsConfig,
        HqqConfig,
@@ -54,6 +54,7 @@ _import_structure = {
    ],
    "eetq": ["replace_with_eetq_linear"],
    "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
    "fp8": ["FP8Linear", "FP8MoELinear", "replace_with_fp8_linear"],
    "fsdp": ["is_fsdp_managed_module"],
    "ggml": [
        "GGUF_CONFIG_MAPPING",
452 src/transformers/integrations/fp8.py Normal file
@@ -0,0 +1,452 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn

from ..utils import is_accelerate_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

ACTIVATION_SCHEMES = ["static", "dynamic"]
quant_dtype = torch.float8_e4m3fn


def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = quant_dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a weight tensor to FP8, returning the quantized tensor and its inverse scale."""
    finfo = torch.finfo(qdtype)
    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp the tensor to bring it into the representable
        # range of the FP8 dtype (the default cast is unsaturated)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
    # Return both the FP8 data and the inverse scale (as float),
    # since both are required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    return qweight, scale
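
A quick round trip makes the scale convention concrete. This is a minimal sketch, assuming a PyTorch build with FP8 dtypes (2.1+), not part of the function above:

w = torch.randn(128, 128)
qweight, inv_scale = fp8_quantize(w)
# Dequantize with the returned inverse scale
w_rec = qweight.to(torch.float32) * inv_scale
print((w - w_rec).abs().max())  # small FP8 rounding error expected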


def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-12,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Performs per-token-group quantization on the input tensor, converting it to FP8.

    Args:
        x (torch.Tensor): Input tensor to quantize (shape: [..., hidden_dim])
        group_size (int): Size of the groups for quantization
        eps (float): Small value to avoid division by zero
        dtype (torch.dtype, optional): FP8 dtype to use. Defaults to a platform-specific choice.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors
    """
    # Input validation
    assert x.ndim >= 2, "Input tensor must have at least 2 dimensions"
    assert x.shape[-1] % group_size == 0, (
        f"Last dimension ({x.shape[-1]}) must be divisible by group_size ({group_size})"
    )

    # Determine the FP8 dtype and its limits (e4m3fnuz on ROCm, e4m3fn elsewhere)
    if dtype is None:
        dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
    finfo = torch.finfo(dtype)

    # Reshape input for group-wise operations
    orig_shape = x.shape
    num_groups = x.shape[-1] // group_size

    # Reshape to [*, num_groups, group_size]
    x_reshaped = x.view(-1, num_groups, group_size)

    # Calculate max absolute values per group
    max_abs = x_reshaped.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps)

    # Calculate scales as dtype max / max_abs
    scales = finfo.max / max_abs

    # Quantize values
    x_scaled = x_reshaped * scales
    x_quant = x_scaled.clamp(min=finfo.min, max=finfo.max).to(dtype)

    # Reshape back to the original shape
    x_quant = x_quant.view(orig_shape)

    # Drop the singleton group dimension and flatten to [num_tokens, num_groups]
    scales = scales.squeeze(-1)
    scales = scales.view(-1, num_groups)

    # Return the reciprocal of the scales for compatibility with other operations
    return x_quant, scales.float().reciprocal()
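
To make the shapes concrete, a small sketch with hypothetical sizes: quantizing a [2, 4, 256] activation with group_size=128 yields one inverse scale per token per group:

x = torch.randn(2, 4, 256)
x_q, inv_scales = per_token_group_quant_fp8(x, group_size=128)
print(x_q.shape)         # torch.Size([2, 4, 256]), now in an FP8 dtype
print(inv_scales.shape)  # torch.Size([8, 2]): 8 tokens x 2 groups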


def per_token_group_dequant_fp8(
    x: torch.Tensor,
    scales: torch.Tensor,
    group_size: int,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Dequantizes an FP8 tensor back to floating point using the group scales.

    Args:
        x (torch.Tensor): Quantized input tensor
        scales (torch.Tensor): Inverse scale factors, as returned by `per_token_group_quant_fp8`
        group_size (int): Size of the groups used during quantization
        output_dtype (torch.dtype): Output dtype

    Returns:
        torch.Tensor: Dequantized tensor
    """
    # Reshape input for group-wise operations
    orig_shape = x.shape
    num_groups = x.shape[-1] // group_size
    x_reshaped = x.view(-1, num_groups, group_size)

    # Ensure the scales have the correct shape
    if scales.ndim == 2:
        scales = scales.view(-1, num_groups, 1)
    else:
        scales = scales.view(*orig_shape[:-1], num_groups, 1)

    # Dequantize
    x_dequant = x_reshaped.to(torch.float32) * scales

    # Reshape back and convert to the desired dtype
    return x_dequant.view(orig_shape).to(output_dtype)
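
The two helpers are inverses up to rounding; a minimal round-trip sketch, again assuming FP8 dtypes are available:

x = torch.randn(2, 4, 256)
x_q, inv_scales = per_token_group_quant_fp8(x, group_size=128)
x_rec = per_token_group_dequant_fp8(x_q, inv_scales, group_size=128, output_dtype=torch.float32)
print((x - x_rec).abs().mean())  # small, dominated by FP8 rounding error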


@torch.compile
def w8a8_block_fp8_matmul(
    input_q: torch.Tensor,       # [batch, seq_len, hidden_dim]
    weight_q: torch.Tensor,      # [out_features, hidden_dim]
    input_scale: torch.Tensor,   # [batch * seq_len, num_input_groups]
    weight_scale: torch.Tensor,  # [num_weight_blocks_m, num_weight_blocks_n]
    block_size: Tuple[int, int],  # (M=128, N=128) for weights
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Performs blocked matrix multiplication with FP8 quantized matrices.

    Args:
        input_q: Quantized input tensor with 1x128 block quantization
        weight_q: Quantized weight tensor with 128x128 block quantization
        input_scale: Scaling factors for the input blocks
        weight_scale: Scaling factors for the weight blocks
        block_size: Tuple of (M, N) weight block dimensions
        output_dtype: Desired output dtype
    """
    batch_size, seq_len, hidden_dim = input_q.shape
    out_features = weight_q.shape[0]

    # Reshape input for batched matmul
    input_reshaped = input_q.view(-1, hidden_dim)  # [batch*seq_len, hidden_dim]

    # Calculate the number of blocks
    num_weight_blocks_m = out_features // block_size[0]
    num_weight_blocks_n = hidden_dim // block_size[1]

    # Initialize the output accumulator in float32
    output = torch.zeros(
        (batch_size * seq_len, out_features),
        dtype=torch.float32,
        device=input_q.device,
    )

    # Process each block
    for i in range(num_weight_blocks_m):
        m_start = i * block_size[0]
        m_end = m_start + block_size[0]

        for j in range(num_weight_blocks_n):
            n_start = j * block_size[1]
            n_end = n_start + block_size[1]

            # Extract the current blocks
            input_block = input_reshaped[:, n_start:n_end]
            weight_block = weight_q[m_start:m_end, n_start:n_end]

            # Get the corresponding scales
            curr_input_scale = input_scale[:, j:j + 1]  # [batch*seq_len, 1]
            curr_weight_scale = weight_scale[i, j]      # scalar

            # Dequantize and multiply in one fused call, accumulating in float32
            block_result = torch._scaled_mm(
                input_block,
                weight_block.t(),
                scale_a=curr_input_scale,
                scale_b=curr_weight_scale,
                out_dtype=torch.float32,
            )
            # Reference implementation without _scaled_mm:
            # block_result = torch.matmul(
            #     input_block.to(torch.float32) * curr_input_scale,
            #     weight_block.to(torch.float32).t() * curr_weight_scale,
            # )

            # Accumulate the partial result
            output[:, m_start:m_end] += block_result

    # Reshape the output back to the original dimensions
    output = output.view(batch_size, seq_len, out_features)

    return output.to(output_dtype)
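
Putting the pieces together, a shape-level sketch of the blocked matmul with hypothetical sizes. An FP8-capable GPU (e.g. H100 or MI300) is assumed, and the per-tensor weight scale from fp8_quantize is broadcast to per-block form purely for illustration; real checkpoints ship genuine per-block scales:

batch, seq, hidden, out_f = 1, 16, 256, 256
block = (128, 128)
x = torch.randn(batch, seq, hidden, device="cuda")
w = torch.randn(out_f, hidden, device="cuda")
x_q, x_inv_scales = per_token_group_quant_fp8(x, group_size=block[1])
w_q, w_inv_scale = fp8_quantize(w)  # per-tensor scale here, per-block in real checkpoints
w_scales = w_inv_scale.expand(out_f // block[0], hidden // block[1])
y = w8a8_block_fp8_matmul(x_q, w_q, x_inv_scales, w_scales, block, output_dtype=torch.float16)
print(y.shape)  # torch.Size([1, 16, 256])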


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Convert e4m3fn weights and scales to the e4m3fnuz format for ROCm compatibility."""
    if weight.dtype != torch.float8_e4m3fn:
        return weight, weight_scale, input_scale

    # Convert -128 (NaN in e4m3fnuz) to 0
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # Double the scales, since e4m3fnuz bit patterns encode half the e4m3fn values
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
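
To see why the scales double, a small sketch assuming FP8 dtypes are available in your PyTorch build (the same bit pattern encodes half the value in e4m3fnuz):

w = torch.tensor([1.0, -1.0, 0.5]).to(torch.float8_e4m3fn)
w_scale = torch.tensor(0.01)
w_uz, w_scale_uz, _ = normalize_e4m3fn_to_e4m3fnuz(w, w_scale)
print(w_uz.dtype)         # torch.float8_e4m3fnuz
print(w_scale_uz.item())  # 0.02: doubled to compensate for the halved encoding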


class FP8Linear(nn.Module):
    """FP8 Linear layer implementation."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        device=None,
        dtype=None,
        activation_scheme="dynamic",
        weight_block_size=None,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation_scheme = activation_scheme
        self.weight_block_size = weight_block_size

        self.weight = nn.Parameter(torch.empty((out_features, in_features), dtype=quant_dtype, device=device))
        self.weight_scale = nn.Parameter(torch.empty(1, dtype=torch.float32, device=device))

        if activation_scheme == "static":
            self.input_scale = nn.Parameter(torch.empty(1, dtype=torch.float32, device=device))
        else:
            self.register_parameter("input_scale", None)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation_scheme == "dynamic":
            # Quantize the activations on the fly, one scale per 1 x group_size block
            qinput, input_scale = per_token_group_quant_fp8(x, self.weight_block_size[1])
        else:
            raise NotImplementedError("Only the 'dynamic' activation scheme is implemented in FP8Linear.forward")

        weight, weight_scale = self.weight, self.weight_scale
        # Handle ROCm compatibility (e4m3fn -> e4m3fnuz)
        if torch.version.hip:
            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale)

        output = w8a8_block_fp8_matmul(
            qinput,
            weight,
            input_scale,
            weight_scale,
            self.weight_block_size,
            output_dtype=x.dtype,
        )

        if self.bias is not None:
            output = output + self.bias

        return output
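
A minimal construction sketch with hypothetical sizes; in practice the weights and scales come from a pre-quantized checkpoint, loaded by the quantizer further down in this commit:

layer = FP8Linear(
    in_features=4096,
    out_features=4096,
    bias=False,
    device="meta",  # shapes and dtypes only, no real allocation
    activation_scheme="dynamic",
    weight_block_size=(128, 128),
)
print(layer.weight.dtype)        # torch.float8_e4m3fn
print(layer.weight_scale.shape)  # torch.Size([1]); replaced by per-block scales at load time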


class FP8MoELinear(FP8Linear):
    """FP8 Linear layer for MoE implementations."""

    def __init__(
        self,
        n_experts: int,
        in_features: int,
        out_features: int,
        bias: bool,
        device=None,
        dtype=None,
        activation_scheme="dynamic",
        weight_block_size=None,
    ):
        super().__init__(
            in_features,
            out_features,
            bias,
            device,
            dtype,
            activation_scheme,
            weight_block_size,
        )
        self.n_experts = n_experts

        # Reshape the weight and scale to hold one slice per expert
        self.weight = nn.Parameter(
            torch.empty((n_experts, out_features, in_features), dtype=quant_dtype, device=device)
        )
        self.weight_scale = nn.Parameter(torch.empty((n_experts, 1), dtype=torch.float32, device=device))

    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
        weight, weight_scale, input_scale = self.weight, self.weight_scale, self.input_scale
        # Handle ROCm compatibility (e4m3fn -> e4m3fnuz)
        if torch.version.hip:
            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale)

        if self.activation_scheme == "dynamic":
            input_scale = x.abs().max() / torch.finfo(quant_dtype).max

        # Select the weights and scales of the routed experts
        selected_weights = weight[expert_indices]
        selected_scales = weight_scale[expert_indices]

        # Perform the FP8 matmul for each expert
        output = torch._scaled_mm(
            x,
            selected_weights.transpose(-1, -2),
            scale_a=input_scale,
            scale_b=selected_scales,
            out_dtype=x.dtype,
        )

        if self.bias is not None:
            output = output + self.bias[expert_indices]

        return output
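
A shape-level sketch of the expert gather above, with hypothetical sizes; the indexing itself needs no FP8 hardware:

n_experts, d_in, d_out, tokens = 8, 256, 256, 16
weight = torch.empty(n_experts, d_out, d_in)  # as laid out in FP8MoELinear
expert_indices = torch.randint(0, n_experts, (tokens,))
selected = weight[expert_indices]  # [tokens, d_out, d_in]: one weight slice per routed token
print(selected.shape)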


def _replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """Replace Linear layers with FP8Linear."""
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
            current_key_name_str = ".".join(current_key_name)
            if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
                with init_empty_weights():
                    # MoE detection (by "gate"/"experts" in the module path) is disabled
                    # for now; every eligible layer goes through the dense FP8Linear
                    # is_moe = any(moe_key in current_key_name_str for moe_key in ["gate", "experts"])
                    is_moe = False

                    if is_moe:
                        n_experts = getattr(model.config, "num_experts", 8)
                        model._modules[name] = FP8MoELinear(
                            n_experts=n_experts,
                            in_features=module.in_features,
                            out_features=module.out_features,
                            bias=module.bias is not None,
                            device=module.weight.device,
                            dtype=module.weight.dtype,
                            activation_scheme=quantization_config.activation_scheme,
                            weight_block_size=quantization_config.weight_block_size,
                        )
                    else:
                        model._modules[name] = FP8Linear(
                            in_features=module.in_features,
                            out_features=module.out_features,
                            bias=module.bias is not None,
                            device=module.weight.device,
                            dtype=module.weight.dtype,
                            activation_scheme=quantization_config.activation_scheme,
                            weight_block_size=quantization_config.weight_block_size,
                        )
                    has_been_replaced = True

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_fp8_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )

        current_key_name.pop(-1)

    return model, has_been_replaced


def replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    """Helper function to replace model layers with their FP8 versions."""
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))

    model, has_been_replaced = _replace_with_fp8_linear(
        model,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using FP8 quantization, but no linear modules were found in your model."
            " Please double check your model architecture."
        )

    return model
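
A hedged sketch of driving the replacement helper on a toy module tree, assuming accelerate is installed (FP8Config is defined later in this commit):

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
config = FP8Config(activation_scheme="dynamic", weight_block_size=(128, 128))
model = replace_with_fp8_linear(model, quantization_config=config)
print(model)  # both Linear layers are now FP8Linear (weights still uninitialized)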
@@ -31,6 +31,7 @@ from ..utils.quantization_config import (
    QuantoConfig,
    TorchAoConfig,
    VptqConfig,
    FP8Config,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -46,7 +47,7 @@ from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer

from .quantizer_fp8 import FP8HfQuantizer

AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
@@ -63,6 +64,7 @@ AUTO_QUANTIZER_MAPPING = {
    "torchao": TorchAoHfQuantizer,
    "bitnet": BitNetHfQuantizer,
    "vptq": VptqHfQuantizer,
    "fp8": FP8HfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -80,6 +82,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "torchao": TorchAoConfig,
    "bitnet": BitNetConfig,
    "vptq": VptqConfig,
    "fp8": FP8Config,
}
180 src/transformers/quantizers/quantizer_fp8.py Normal file
@@ -0,0 +1,180 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch

from .base import HfQuantizer
from ..utils import is_accelerate_available, logging
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)


class FP8HfQuantizer(HfQuantizer):
    """
    FP8 quantization implementation, supporting both standard and MoE models.
    Supports both the e4m3fn and e4m3fnuz formats, chosen based on the platform.
    """

    # create_quantized_param below quantizes weights parameter by parameter
    requires_parameters_quantization = True
    requires_calibration = False
    required_packages = ["accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.is_moe_model = kwargs.get("is_moe_model", False)

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available():
            raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")

        if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
            raise ValueError(
                "Converting into FP8 weights from tf/flax weights is currently not supported, "
                "please make sure the weights are in PyTorch format."
            )

        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")

        device_map = kwargs.get("device_map", None)
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model on CPU and have a CUDA device available. "
                "Make sure to set your model on a GPU device to run your model."
            )
        elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "FP8 models do not support CPU or disk offloading in the device map. "
                "Please remove CPU/disk devices from the device map."
            )

    def update_torch_dtype(self, torch_dtype):
        # Keep the non-quantized parameters (biases, norms) in float32
        torch_dtype = torch.float32
        return torch_dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """
        Quantizes weights to the FP8 format using either:
        - Block-wise quantization when weight_block_size is provided
        - Per-tensor quantization when weight_block_size is None
        """
        module, tensor_name = get_module_from_name(model, param_name)

        # Get the FP8 min/max values
        fp8_min = torch.finfo(torch.float8_e4m3fn).min
        fp8_max = torch.finfo(torch.float8_e4m3fn).max

        if self.quantization_config.weight_block_size is not None:
            # Block-wise quantization
            block_size_m, block_size_n = self.quantization_config.weight_block_size

            # Get the matrix dimensions
            rows, cols = param_value.shape[-2:]

            # Check that the dimensions are divisible by the block sizes
            if rows % block_size_m != 0 or cols % block_size_n != 0:
                raise ValueError(
                    f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes "
                    f"({block_size_m}, {block_size_n})"
                )

            orig_shape = param_value.shape

            # Create blocks using unfold: [..., rows/bm, cols/bn, bm, bn]
            param_value = param_value.unfold(-2, block_size_m, block_size_m)
            param_value = param_value.unfold(-2, block_size_n, block_size_n)

            # Calculate a scaling factor for each block
            max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
            scale = fp8_max / max_abs

            # Expand the scale to match the block dimensions for the multiplication
            scale = scale.unsqueeze(-1).unsqueeze(-1)

            # Quantize the weights
            quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max)

            # Reassemble the blocks into the original matrix layout:
            # [..., rows/bm, cols/bn, bm, bn] -> [..., rows/bm, bm, cols/bn, bn] -> [..., rows, cols]
            quantized_param = quantized_param.transpose(-3, -2).reshape(orig_shape)

            # Keep one scale per block, as [rows/bm, cols/bn], inverted for torch._scaled_mm
            scale = scale.squeeze(-1).squeeze(-1).reciprocal()

        else:
            # Per-tensor quantization
            max_abs = torch.max(torch.abs(param_value))
            scale = fp8_max / max_abs

            # Quantize the weights
            quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max)

            # For per-tensor quantization, a single (inverse) scale value is enough
            scale = scale.reciprocal().reshape(1)

        # Store the quantized weights and scales in the module
        quantized_param = quantized_param.to(torch.float8_e4m3fn)
        module._parameters[tensor_name] = torch.nn.Parameter(
            quantized_param.to(target_device), requires_grad=False
        )
        module._parameters["weight_scale"] = torch.nn.Parameter(scale.to(target_device), requires_grad=False)
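
A standalone sketch of the unfold-based blocking used above, with a hypothetical 256x256 weight and 128x128 blocks:

w = torch.randn(256, 256)
blocks = w.unfold(-2, 128, 128).unfold(-2, 128, 128)  # [2, 2, 128, 128]
scale = torch.finfo(torch.float8_e4m3fn).max / torch.amax(blocks.abs(), dim=(-1, -2))
print(blocks.shape, scale.shape)  # one scale per 128x128 block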

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from ..integrations.fp8 import FP8Linear, FP8MoELinear

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, (FP8Linear, FP8MoELinear)):
            if tensor_name == "bias":
                return False
            if tensor_name == "weight":
                return True
            if tensor_name in ["weight_scale", "input_scale"]:
                return False
        return False

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        from ..integrations.fp8 import replace_with_fp8_linear

        self.modules_to_not_convert = ["lm_head"] + keep_in_fp32_modules

        if self.quantization_config.modules_to_not_convert:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

        model = replace_with_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
        )

        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # Leave ~10% headroom for quantization buffers
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        # FP8 quantization is typically used for inference only
        return False
@@ -21,7 +21,7 @@ import os
 from dataclasses import dataclass
 from enum import Enum
 from inspect import Parameter, signature
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Tuple

 from packaging import version
@@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum):
    FBGEMM_FP8 = "fbgemm_fp8"
    TORCHAO = "torchao"
    BITNET = "bitnet"
    FP8 = "fp8"


class AWQLinearVersion(str, Enum):
@@ -1548,3 +1549,24 @@ class BitNetConfig(QuantizationConfigMixin):
        Safety checker that arguments are correct
        """
        pass


@dataclass
class FP8Config(QuantizationConfigMixin):
    """
    Configuration for FP8 quantization.

    Args:
        modules_to_not_convert (`List`, *optional*): Modules to keep in full precision (e.g. `lm_head`).
        activation_scheme (`str`, *optional*, defaults to `"dynamic"`): `"dynamic"` or `"static"` activation quantization.
        weight_block_size (`Tuple[int, int]`, *optional*): Block shape for block-wise weight quantization, e.g. `(128, 128)`.
    """

    def __init__(
        self,
        modules_to_not_convert: Optional[List] = None,
        activation_scheme: Optional[str] = "dynamic",
        weight_block_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FP8
        self.modules_to_not_convert = modules_to_not_convert
        self.activation_scheme = activation_scheme
        self.weight_block_size = weight_block_size
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        if self.activation_scheme not in ["dynamic", "static"]:
            raise ValueError(f"Activation scheme {self.activation_scheme} is not supported")
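
For context, a hedged end-to-end sketch of how the new config is meant to be used; the model id below is a placeholder, not a real checkpoint:

from transformers import AutoModelForCausalLM, FP8Config

quantization_config = FP8Config(activation_scheme="dynamic", weight_block_size=(128, 128))
model = AutoModelForCausalLM.from_pretrained(
    "org/model",  # placeholder checkpoint id
    quantization_config=quantization_config,
    device_map="auto",
)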