This commit is contained in:
MekkCyber 2025-02-08 11:18:52 +00:00
parent 83912834cc
commit 33f73712dc
8 changed files with 496 additions and 241 deletions

View file

@ -0,0 +1,60 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FP8
With the FP8 quantization method, you can quantize your model in FP8 (W8A8):
- the weights are quantized in 8-bit (FP8) per 2D block (e.g. `weight_block_size=(128, 128)`), a scheme inspired by the DeepSeek-V3 implementation
- the activations are quantized in 8-bit (FP8) per group per token

FP8 quantization was added to support the DeepSeek-V3 and DeepSeek-R1 models; see the [DeepSeek-V3 paper](https://arxiv.org/pdf/2412.19437) for details. A sketch of the weight scheme is shown below.
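Conceptually, each (128, 128) weight tile gets one scale derived from its largest absolute value, and the reciprocal of that scale is stored for dequantization. Here is a minimal PyTorch sketch of the idea (illustrative only; the function and variable names below are not part of the transformers API):

```py
import torch

def quantize_weight_blockwise(weight: torch.Tensor, block: int = 128):
    # Illustrative sketch: assumes a 2D weight whose shape divides evenly by `block`
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    rows, cols = weight.shape
    # Cut the matrix into (block x block) tiles
    tiles = weight.reshape(rows // block, block, cols // block, block).permute(0, 2, 1, 3)
    # One scale per tile, derived from the tile's max absolute value
    scale = fp8_max / tiles.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=1e-12)
    qweight = (tiles * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    qweight = qweight.permute(0, 2, 1, 3).reshape(rows, cols)
    # Store the reciprocal so dequantization is one multiply per tile
    return qweight, scale.squeeze(-1).squeeze(-1).reciprocal()
```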
> [!TIP]
> You need a GPU with compute capability >= 9.0 (e.g. H100).
Before you begin, make sure the following libraries are installed with their latest version:
```bash
pip install --upgrade accelerate torch
```
> [!TIP]
> You need to install a torch version compatible with your GPU's CUDA version.
By default, the weights are loaded in full precision (torch.float32) regardless of the data type they were saved in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in the model's `config.json` file, which is usually the most memory-efficient option.
```py
from transformers import FP8Config, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
A quantized model can be saved via `save_pretrained` and reloaded via `from_pretrained`.
```py
quant_path = "/path/to/save/quantized/model"
quantized_model.save_pretrained(quant_path)
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```

View file

@ -54,7 +54,7 @@ _import_structure = {
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fp8": ["FP8Linear", "FP8MoELinear", "replace_with_fp8_linear"],
"fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
"GGUF_CONFIG_MAPPING",

View file

@ -13,41 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import triton
import triton.language as tl
from torch.nn import functional as F
from triton import Config

from ..utils import is_accelerate_available, logging
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
ACTIVATION_SCHEMES = ["dynamic"]
quant_dtype = torch.float8_e4m3fn
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
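    # One Triton program handles one BLOCK_SIZE-element group: it computes the group's
    # max-abs value, derives a dynamic scale so the group fits the FP8 e4m3 range
    # (|x| <= 448), and stores both the quantized values and the per-group scale.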
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
@ -56,81 +46,17 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)

    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
@ -190,12 +116,8 @@ def _w8a8_block_fp8_matmul(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
@ -218,7 +140,9 @@ def _w8a8_block_fp8_matmul(
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul_triton(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@ -254,7 +178,7 @@ def w8a8_block_fp8_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
    C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
# TODO:
@ -270,8 +194,7 @@ def w8a8_block_fp8_matmul(
BLOCK_SIZE_N = block_n
def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
_w8a8_block_fp8_matmul[grid](
A,
@ -303,33 +226,118 @@ def w8a8_block_fp8_matmul(
return C
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
# Pure PyTorch (torch.compile) version of the Triton matmul above
@torch.compile
def w8a8_block_fp8_matmul_compile(
input_q: torch.Tensor, # [batch, seq_len, hidden_dim]
weight_q: torch.Tensor, # [out_features, hidden_dim]
input_scale: torch.Tensor, # [batch * seq_len, num_input_groups]
weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n]
block_size: Optional[Tuple[int, int]] = None, # (M=128, N=128) for weights for example
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Performs blocked matrix multiplication with FP8 quantized matrices.
def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, activation_scheme: str = "dynamic") -> torch.Tensor:
Args:
input_q: Quantized input tensor with 1x128 block quantization
weight_q: Quantized weight tensor with 128x128 block quantization
input_scale: Scaling factors for input blocks
weight_scale: Scaling factors for weight blocks
block_size: Tuple of (M, N) for weight block dimensions
output_dtype: Desired output dtype
"""
batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
out_features = weight_q.shape[0]
# Reshape input for batched matmul
input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim]
input_scale_reshaped = input_scale.view(input_scale.shape[0], -1) # [batch*seq_len, 1]
# Calculate number of blocks
num_weight_blocks_m = out_features // block_size[0]
num_weight_blocks_n = hidden_dim // block_size[1]
# Initialize output tensor
output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)
# Process each block
for i in range(num_weight_blocks_m):
m_start = i * block_size[0]
m_end = m_start + block_size[0]
for j in range(num_weight_blocks_n):
n_start = j * block_size[1]
n_end = n_start + block_size[1]
# Extract current blocks
input_block = input_reshaped[:, n_start:n_end]
weight_block = weight_q[m_start:m_end, n_start:n_end]
# Get corresponding scales
curr_input_scale = input_scale_reshaped[:, j : j + 1] # [batch*seq_len, 1]
curr_weight_scale = weight_scale[i, j] # scalar
block_result = (
torch._scaled_mm(
input_block,
weight_block.t(),
scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device),
scale_b=curr_weight_scale,
out_dtype=output_dtype,
)
* curr_input_scale
)
output[:, m_start:m_end] += block_result
output = output.view(batch_size, seq_len, out_features)
return output.to(output_dtype)
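# Example with hypothetical shapes, assuming the default (128, 128) weight blocks and
# 1x128 activation groups: input_q [2, 16, 256] with input_scale reshaped to [32, 2],
# weight_q [512, 256] with weight_scale [4, 2] -> output [2, 16, 512].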
def linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
block_size: Optional[Tuple[int, int]] = None,
activation_scheme: str = "dynamic",
) -> torch.Tensor:
    if weight.element_size() > 1:
        # weight is unquantized (more than one byte per element): plain linear
        return F.linear(input, weight, bias)
    else:
        with torch.cuda.device(input.device):
            qinput, scale = act_quant(input, block_size[1])
            torch.cuda.synchronize(device=input.device)
        with torch.cuda.device(input.device):
            output = w8a8_block_fp8_matmul_triton(
                qinput,
                weight,
                scale,
                weight_scale,
                block_size,
                output_dtype=input.dtype,
            )
            torch.cuda.synchronize(device=input.device)
        if bias is not None:
            output = output + bias
        return output.to(dtype=input.dtype)
class FP8Linear(nn.Linear):
dtype = torch.float8_e4m3fn
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype=None,
        block_size: Optional[Tuple[int, int]] = None,
        device=None,
        activation_scheme="dynamic",
    ):
super().__init__(in_features=in_features, out_features=out_features)
self.in_features = in_features
self.out_features = out_features
@ -340,84 +348,29 @@ class FP8Linear(nn.Linear):
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
            self.weight_scale_inv = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
            )
else:
self.register_parameter("weight_scale_inv", None)
self.block_size = block_size
if activation_scheme != "dynamic":
raise ValueError(f"Only dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}")
raise ValueError(
f"Only dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}"
)
self.activation_scheme = activation_scheme
if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return linear(x, self.weight, self.weight_scale_inv, self.bias, self.block_size, self.activation_scheme)
class FP8MoELinear(FP8Linear):
"""FP8 Linear layer for MoE implementation."""
def __init__(
self,
n_experts: int,
in_features: int,
out_features: int,
bias: bool,
device=None,
dtype=None,
activation_scheme="dynamic",
weight_block_size=None
):
super().__init__(
in_features,
out_features,
bias,
device,
dtype,
activation_scheme,
weight_block_size
)
self.n_experts = n_experts
# Reshape weight and scale for experts
self.weight = nn.Parameter(
torch.empty((n_experts, out_features, in_features),
dtype=quant_dtype,
device=device)
)
self.weight_scale = nn.Parameter(
torch.empty((n_experts, 1),
dtype=torch.float32,
device=device)
)
def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
if self.activation_scheme == "dynamic":
input_scale = x.abs().max() / torch.finfo(quant_dtype).max
# Select expert weights and scales
selected_weights = self.weight[expert_indices]
selected_scales = self.weight_scale[expert_indices]
# Perform FP8 matmul for each expert
output = torch._scaled_mm(
x,
selected_weights.transpose(-1, -2),
scale_a=input_scale,
scale_b=selected_scales,
out_dtype=x.dtype
)
if self.bias is not None:
output = output + self.bias[expert_indices]
return output
def _replace_with_fp8_linear(
model,
modules_to_not_convert=None,
@ -431,12 +384,11 @@ def _replace_with_fp8_linear(
for name, module in model.named_children():
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
@ -444,10 +396,10 @@ def _replace_with_fp8_linear(
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
@ -456,11 +408,12 @@ def _replace_with_fp8_linear(
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_fp8_linear(
model,
modules_to_not_convert=None,
@ -468,21 +421,21 @@ def replace_with_fp8_linear(
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fp8_linear(
model,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using fp8 but no linear modules were found in your model."
" Please double check your model architecture."
)
    return model

View file

@ -24,6 +24,7 @@ from ..utils.quantization_config import (
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
FP8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
@ -32,7 +33,6 @@ from ..utils.quantization_config import (
QuantoConfig,
TorchAoConfig,
VptqConfig,
FP8Config,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@ -42,13 +42,14 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_fp8 import FP8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
from .quantizer_fp8 import FP8HfQuantizer
AUTO_QUANTIZER_MAPPING = {
"awq": AwqQuantizer,

View file

@ -1,22 +1,26 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import torch.nn as nn

from ..utils import is_accelerate_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
logger = logging.get_logger(__name__)
class FP8HfQuantizer(HfQuantizer):
"""
FP8 quantization implementation supporting both standard and MoE models.
Supports both e4m3fn and e4m3fnuz formats based on platform.
"""
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["accelerate"]
@ -38,30 +42,35 @@ class FP8HfQuantizer(HfQuantizer):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if major < 9:
raise ValueError(
"FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
)
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded an FP8 model on CPU and have a CUDA device available. "
"Make sure to set your model on a GPU device to run your model."
)
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"FP8 models do not support CPU or disk offloading in the device map. "
"Please remove CPU/disk devices from the device map."
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
)
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
            # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
            torch_dtype = torch.float32
        return torch_dtype
    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
"""
Quantizes weights to FP8 format using either:
@ -74,39 +83,39 @@ class FP8HfQuantizer(HfQuantizer):
fp8_max = torch.finfo(torch.float8_e4m3fn).max
        if self.quantization_config.weight_block_size is None:
            self.quantization_config.weight_block_size = (128, 128)
        block_size_m, block_size_n = self.quantization_config.weight_block_size
rows, cols = param_value.shape[-2:]
# Check if dimensions are divisible by block sizes
if rows % block_size_m != 0 or cols % block_size_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
)
param_value_orig_shape = param_value.shape
# Create blocks using unfold
        param_value = param_value.reshape(
            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
        ).permute(0, 1, 3, 2, 4)
# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
# Expand scale to match block dimensions for multiplication
scale = scale.unsqueeze(-1).unsqueeze(-1)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)
# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
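        # The stored parameter is the reciprocal (dequantization) scale: the original
        # weight is recovered per block as qweight * weight_scale_inv.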
module._parameters[tensor_name] = quantized_param.to(target_device)
module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
@ -118,17 +127,17 @@ class FP8HfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
):
        from ..integrations.fp8 import FP8Linear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
@ -141,28 +150,27 @@ class FP8HfQuantizer(HfQuantizer):
**kwargs,
):
from ..integrations.fp8 import replace_with_fp8_linear
self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
if self.quantization_config.modules_to_not_convert:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
model = replace_with_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FP8Linear
@ -181,6 +189,6 @@ class FP8HfQuantizer(HfQuantizer):
def is_serializable(self, safe_serialization=None):
return True
    @property
def is_trainable(self) -> bool:
        return False  # FP8 quantization is typically used for inference only

View file

@ -21,7 +21,7 @@ import os
from dataclasses import dataclass
from enum import Enum
from inspect import Parameter, signature
from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
@ -1550,13 +1550,14 @@ class BitNetConfig(QuantizationConfigMixin):
"""
pass
@dataclass
class FP8Config(QuantizationConfigMixin):
def __init__(
self,
modules_to_not_convert: Optional[List] = None,
activation_scheme: Optional[str] = "dynamic",
        weight_block_size: Optional[Tuple[int, int]] = (128, 128),
**kwargs,
):
self.quant_method = QuantizationMethod.FP8
@ -1569,4 +1570,10 @@ class FP8Config(QuantizationConfigMixin):
r"""
Safety checker that arguments are correct
"""
        self.activation_scheme = self.activation_scheme.lower()
if self.activation_scheme not in ["dynamic"]:
raise ValueError(f"Activation scheme {self.activation_scheme} not supported")
if len(self.weight_block_size) != 2:
raise ValueError("weight_block_size must be a tuple of two integers")
if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
raise ValueError("weight_block_size must be a tuple of two positive integers")

View file

@ -0,0 +1,226 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FP8Config, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class FP8ConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
"""
quantization_config = FP8Config()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
"""
dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fp8"}
quantization_config = FP8Config.from_dict(dict)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
@slow
@require_accelerate
@require_torch_gpu
class FP8QuantizerTest(unittest.TestCase):
model_name = "meta-llama/Llama-3.2-1B"
input_text = "Once upon a time"
max_new_tokens = 10
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
device_map = "cuda"
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
quantization_config = FP8Config()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from transformers.integrations import FP8Linear, replace_with_fp8_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = FP8Config()
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
nb_fp8_linear = 0
for module in model.modules():
if isinstance(module, FP8Linear):
nb_fp8_linear += 1
self.assertEqual(nb_linears - 1, nb_fp8_linear)
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = FP8Config(modules_to_not_convert=["fc1"])
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
nb_fp8_linear = 0
for module in model.modules():
if isinstance(module, FP8Linear):
nb_fp8_linear += 1
self.assertEqual(nb_linears - 25, nb_fp8_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_weight_and_weight_scale_inv(self):
"""
Simple test that checks if the weight and weight_scale_inv are working properly
"""
weight = self.quantized_model.model.layers[0].self_attn.q_proj.weight
weight_scale_inv = self.quantized_model.model.layers[0].self_attn.q_proj.weight_scale_inv
self.assertEqual(weight.dtype, torch.float8_e4m3fn)
self.assertEqual(weight_scale_inv.dtype, torch.float32)
self.assertEqual(weight.shape, (weight_scale_inv.shape[0] * 128, weight_scale_inv.shape[1] * 128))
def test_block_size(self):
"""
Simple test that checks if the block size is working properly
"""
self.assertEqual(self.quantized_model.config.quantization_config.weight_block_size, (128, 128))
quantization_config = FP8Config(weight_block_size=(32, 32))
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map=self.device_map, quantization_config=quantization_config
)
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
quantization_config = FP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_save_pretrained_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
self.assertTrue(set(model.hf_device_map.values()) == {0, 1})
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_gpu
class FP8LinearTest(unittest.TestCase):
device = "cuda"
def test_linear_preserves_shape(self):
"""
Test that FP8Linear preserves shape when in_features == out_features.
"""
from transformers.integrations import FP8Linear
linear = FP8Linear(256, 256, device=self.device)
x = torch.rand((1, 5, 256)).to(self.device)
x_ = linear(x)
self.assertEqual(x_.shape, x.shape)
def test_linear_with_diff_feature_size_preserves_shape(self):
"""
Test that FP8Linear generates the correct shape when in_features != out_features.
"""
from transformers.integrations import FP8Linear
linear = FP8Linear(128, 256, device=self.device)
x = torch.rand((1, 5, 128)).to(self.device)
x_ = linear(x)
self.assertEqual(x_.shape, (1, 5, 256))