mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00

end2end

This commit is contained in:
parent 83912834cc
commit 33f73712dc

8 changed files with 496 additions and 241 deletions
@ -0,0 +1,60 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# FP8

With the FP8 quantization method, you can quantize your model in FP8 (W8A8):

- the weights will be quantized in 8-bit (FP8) per 2D block (e.g. `weight_block_size=(128, 128)`), which is inspired by the DeepSeek implementation
- the activations will be quantized in 8-bit (FP8) per group per token

It was implemented to add support for the DeepSeek-V3 and DeepSeek-R1 models; see the paper [here](https://arxiv.org/pdf/2412.19437).
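If you are curious what this looks like numerically, here is a minimal sketch of per-block FP8 weight quantization in plain PyTorch (an illustration only, not the code used internally; the 256x256 weight and the block size are made up for the example):

```py
import torch

weight = torch.randn(256, 256)
block_m, block_n = 128, 128
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

# View the matrix as a grid of (128, 128) blocks and compute one scale per block.
blocks = weight.reshape(256 // block_m, block_m, 256 // block_n, block_n).permute(0, 2, 1, 3)
scale = fp8_max / blocks.abs().amax(dim=(-1, -2), keepdim=True)
qweight = (blocks * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Reassemble the (256, 256) layout and keep the reciprocal per-block scale, which is what gets stored.
qweight = qweight.permute(0, 2, 1, 3).reshape(256, 256)
weight_scale_inv = scale.squeeze(-1).squeeze(-1).reciprocal()
```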
> [!TIP]
> You need a GPU with compute capability >= 9 (e.g. H100).

Before you begin, make sure the following libraries are installed with their latest version:

```bash
pip install --upgrade accelerate torch
```

> [!TIP]
> You need to install a torch version compatible with the CUDA version of your GPU.

By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file, which automatically loads the most memory-optimal data type.
```py
from transformers import FP8Config, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

A quantized model can be saved via `save_pretrained` and reloaded via `from_pretrained`.
```py
quant_path = "/path/to/save/quantized/model"
quantized_model.save_pretrained(quant_path)
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
@ -54,7 +54,7 @@ _import_structure = {
    ],
    "eetq": ["replace_with_eetq_linear"],
    "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
    "fp8": ["FP8Linear", "FP8MoELinear", "replace_with_fp8_linear"],
    "fp8": ["FP8Linear", "replace_with_fp8_linear"],
    "fsdp": ["is_fsdp_managed_module"],
    "ggml": [
        "GGUF_CONFIG_MAPPING",
@ -13,41 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from typing import Optional, List, Tuple, Union
from ..utils import is_accelerate_available, logging
from torch.nn import functional as F
import triton
import triton.language as tl
from triton import Config
from torch.nn import functional as F

from ..utils import is_accelerate_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights


logger = logging.get_logger(__name__)

ACTIVATION_SCHEMES = ["dynamic"]
quant_dtype = torch.float8_e4m3fn

# def fp8_quantize(weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#     """Quantize weights to FP8."""

#     # Calculate scale as max value divided by absmax
#     scale = 448.0 / weight.abs().max().clamp(min=1e-12)
#     # Scale and clamp tensor to FP8 range
#     qweight = (weight * scale).clamp(min=-448.0, max=448.0)
#     scale = scale.float().reciprocal()

#     qweight = qweight.to(quant_dtype)
#     return qweight, scale

@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
@ -56,81 +46,17 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):

def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous()
    assert x.shape[-1] % block_size[0] == 0
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size[0], dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size[0])
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)

    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
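For reference, a rough pure-PyTorch equivalent of what `act_quant` produces (a sketch for illustration only, assuming the last dimension is divisible by `block_size`; the real path uses the Triton kernel above):

```py
import torch

def act_quant_reference(x: torch.Tensor, block_size: int = 128):
    # Per-group (1 x block_size) dynamic activation quantization to FP8.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    groups = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
    scale = groups.abs().amax(dim=-1, keepdim=True) / fp8_max
    y = (groups / scale).to(torch.float8_e4m3fn).reshape_as(x)
    return y, scale.squeeze(-1)
```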
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    assert x.is_contiguous() and s.is_contiguous()
    assert x.dim() == 2 and s.dim() == 2
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]

@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
@ -190,12 +116,8 @@ def _w8a8_block_fp8_matmul(

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
@ -218,7 +140,9 @@ def _w8a8_block_fp8_matmul(
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def w8a8_block_fp8_matmul(
def w8a8_block_fp8_matmul_triton(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
@ -254,7 +178,7 @@ def w8a8_block_fp8_matmul(
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # TODO:
@ -270,8 +194,7 @@ def w8a8_block_fp8_matmul(
    BLOCK_SIZE_N = block_n

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
                triton.cdiv(N, META["BLOCK_SIZE_N"]), )
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)

    _w8a8_block_fp8_matmul[grid](
        A,
@ -303,33 +226,118 @@ def w8a8_block_fp8_matmul(
    return C


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    assert a.is_contiguous() and b.is_contiguous()
    assert a_s.is_contiguous() and b_s.is_contiguous()
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c


# Python version of the above triton function
@torch.compile
def w8a8_block_fp8_matmul_compile(
    input_q: torch.Tensor,  # [batch, seq_len, hidden_dim]
    weight_q: torch.Tensor,  # [out_features, hidden_dim]
    input_scale: torch.Tensor,  # [batch * seq_len, num_input_groups]
    weight_scale: torch.Tensor,  # [num_weight_blocks_m, num_weight_blocks_n]
    block_size: Optional[Tuple[int, int]] = None,  # (M=128, N=128) for weights for example
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Performs blocked matrix multiplication with FP8 quantized matrices.

def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, activation_scheme: str = "dynamic") -> torch.Tensor:

    Args:
        input_q: Quantized input tensor with 1x128 block quantization
        weight_q: Quantized weight tensor with 128x128 block quantization
        input_scale: Scaling factors for input blocks
        weight_scale: Scaling factors for weight blocks
        block_size: Tuple of (M, N) for weight block dimensions
        output_dtype: Desired output dtype
    """
    batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
    out_features = weight_q.shape[0]

    # Reshape input for batched matmul
    input_reshaped = input_q.view(-1, hidden_dim)  # [batch*seq_len, hidden_dim]
    input_scale_reshaped = input_scale.view(input_scale.shape[0], -1)  # [batch*seq_len, 1]
    # Calculate number of blocks
    num_weight_blocks_m = out_features // block_size[0]
    num_weight_blocks_n = hidden_dim // block_size[1]

    # Initialize output tensor
    output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)

    # Process each block
    for i in range(num_weight_blocks_m):
        m_start = i * block_size[0]
        m_end = m_start + block_size[0]

        for j in range(num_weight_blocks_n):
            n_start = j * block_size[1]
            n_end = n_start + block_size[1]

            # Extract current blocks
            input_block = input_reshaped[:, n_start:n_end]
            weight_block = weight_q[m_start:m_end, n_start:n_end]

            # Get corresponding scales
            curr_input_scale = input_scale_reshaped[:, j : j + 1]  # [batch*seq_len, 1]
            curr_weight_scale = weight_scale[i, j]  # scalar

            block_result = (
                torch._scaled_mm(
                    input_block,
                    weight_block.t(),
                    scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device),
                    scale_b=curr_weight_scale,
                    out_dtype=output_dtype,
                )
                * curr_input_scale
            )

            output[:, m_start:m_end] += block_result

    output = output.view(batch_size, seq_len, out_features)

    return output.to(output_dtype)


def linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    activation_scheme: str = "dynamic",
) -> torch.Tensor:
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
        print("value not the one expected")
        return F.linear(input, weight, bias)
    else:
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight_scale)
        # y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size)
        with torch.cuda.device(input.device):
            qinput, scale = act_quant(input, block_size[1])
            torch.cuda.synchronize(device=input.device)
        with torch.cuda.device(input.device):
            output = w8a8_block_fp8_matmul_triton(
                qinput,
                weight,
                scale,
                weight_scale,
                block_size,
                output_dtype=input.dtype,
            )
            torch.cuda.synchronize(device=input.device)
        if bias is not None:
            y += bias
            return y
            output = output + bias
        return output.to(dtype=input.dtype)
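For reference, the FP8 linear layer built on top of this function is exercised roughly as follows (a sketch based on the tests further down in this commit; it assumes a CUDA GPU with FP8 support):

```py
import torch
from transformers.integrations import FP8Linear

layer = FP8Linear(256, 256, device="cuda")
x = torch.rand((1, 5, 256), device="cuda")
y = layer(x)  # same shape as x, computed through the FP8 path above
print(y.shape)
```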
class FP8Linear(nn.Linear):
    dtype = torch.float8_e4m3fn

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None, block_size: Optional[Tuple[int, int]] = None, device=None, activation_scheme="dynamic"):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype=None,
        block_size: Optional[Tuple[int, int]] = None,
        device=None,
        activation_scheme="dynamic",
    ):
        super().__init__(in_features=in_features, out_features=out_features)
        self.in_features = in_features
        self.out_features = out_features
@ -340,84 +348,29 @@ class FP8Linear(nn.Linear):
        if self.weight.element_size() == 1:
            scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
            scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
            self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device))
            self.weight_scale_inv = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
            )
        else:
            self.register_parameter("weight_scale_inv", None)

        self.block_size = block_size

        if activation_scheme != "dynamic":
            raise ValueError(f"Only dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}")
            raise ValueError(
                f"Only dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}"
            )
        self.activation_scheme = activation_scheme

        if bias:
            self.bias = nn.Parameter(torch.empty(self.part_out_features))
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.weight_scale_inv, self.bias, self.block_size, self.activation_scheme)


class FP8MoELinear(FP8Linear):
    """FP8 Linear layer for MoE implementation."""

    def __init__(
        self,
        n_experts: int,
        in_features: int,
        out_features: int,
        bias: bool,
        device=None,
        dtype=None,
        activation_scheme="dynamic",
        weight_block_size=None
    ):
        super().__init__(
            in_features,
            out_features,
            bias,
            device,
            dtype,
            activation_scheme,
            weight_block_size
        )
        self.n_experts = n_experts

        # Reshape weight and scale for experts
        self.weight = nn.Parameter(
            torch.empty((n_experts, out_features, in_features),
                        dtype=quant_dtype,
                        device=device)
        )
        self.weight_scale = nn.Parameter(
            torch.empty((n_experts, 1),
                        dtype=torch.float32,
                        device=device)
        )

    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:

        if self.activation_scheme == "dynamic":
            input_scale = x.abs().max() / torch.finfo(quant_dtype).max

        # Select expert weights and scales
        selected_weights = self.weight[expert_indices]
        selected_scales = self.weight_scale[expert_indices]

        # Perform FP8 matmul for each expert
        output = torch._scaled_mm(
            x,
            selected_weights.transpose(-1, -2),
            scale_a=input_scale,
            scale_b=selected_scales,
            out_dtype=x.dtype
        )

        if self.bias is not None:
            output = output + self.bias[expert_indices]

        return output


def _replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
@ -431,12 +384,11 @@ def _replace_with_fp8_linear(

    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
            current_key_name_str = ".".join(current_key_name)
            if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
                with init_empty_weights():
                    # Check if this is an MoE layer
                    model._modules[name] = FP8Linear(
                        in_features=module.in_features,
                        out_features=module.out_features,
@ -444,10 +396,10 @@ def _replace_with_fp8_linear(
                        device=module.weight.device,
                        dtype=module.weight.dtype,
                        activation_scheme=quantization_config.activation_scheme,
                        block_size=quantization_config.weight_block_size
                        block_size=quantization_config.weight_block_size,
                    )
                    has_been_replaced = True

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_fp8_linear(
                module,
@ -456,11 +408,12 @@ def _replace_with_fp8_linear(
                quantization_config,
                has_been_replaced=has_been_replaced,
            )

        current_key_name.pop(-1)

    return model, has_been_replaced


def replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
@ -468,21 +421,21 @@ def replace_with_fp8_linear(
):
    """Helper function to replace model layers with FP8 versions."""
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))

    model, has_been_replaced = _replace_with_fp8_linear(
        model,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )

        return model
    return model
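For illustration, `replace_with_fp8_linear` is normally driven by the quantizer, but it can be exercised directly the way the tests in this commit do (a sketch; the OPT checkpoint is just the example the tests use):

```py
from accelerate import init_empty_weights
from transformers import AutoConfig, FP8Config, OPTForCausalLM
from transformers.integrations import replace_with_fp8_linear

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

# Swap every nn.Linear (except lm_head by default) for an FP8Linear shell.
model = replace_with_fp8_linear(model, quantization_config=FP8Config())
```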
@ -24,6 +24,7 @@ from ..utils.quantization_config import (
    CompressedTensorsConfig,
    EetqConfig,
    FbgemmFp8Config,
    FP8Config,
    GPTQConfig,
    HiggsConfig,
    HqqConfig,

@ -32,7 +33,6 @@ from ..utils.quantization_config import (
    QuantoConfig,
    TorchAoConfig,
    VptqConfig,
    FP8Config,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer

@ -42,13 +42,14 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_fp8 import FP8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
from .quantizer_fp8 import FP8HfQuantizer


AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
@ -1,22 +1,26 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter as Parameter
from .base import HfQuantizer

from ..utils import is_accelerate_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)


class FP8HfQuantizer(HfQuantizer):
    """
    FP8 quantization implementation supporting both standard and MoE models.
    Supports both e4m3fn and e4m3fnuz formats based on platform.
    """

    requires_parameters_quantization = True
    requires_calibration = False
    required_packages = ["accelerate"]
@ -38,30 +42,35 @@ class FP8HfQuantizer(HfQuantizer):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")

        compute_capability = torch.cuda.get_device_capability()
        major, minor = compute_capability
        if major < 9:
            raise ValueError(
                "FP8 quantized models are only supported on GPUs with compute capability >= 9.0 (e.g. H100)"
            )

        device_map = kwargs.get("device_map", None)
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model on CPU and have a CUDA device available. "
                "Make sure to set your model on a GPU device to run your model."
            )
        elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "FP8 models do not support CPU or disk offloading in the device map. "
                "Please remove CPU/disk devices from the device map."
                "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
                "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
            )

    def update_torch_dtype(self, torch_dtype):
        torch_dtype = torch.float32
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
            # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
            torch_dtype = torch.float32
        return torch_dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """
        Quantizes weights to FP8 format using either:
@ -74,39 +83,39 @@ class FP8HfQuantizer(HfQuantizer):
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        if self.quantization_config.weight_block_size is None:
            self.quantization_config.weight_block_size = (128, 128)
        else :
            block_size_m, block_size_n = self.quantization_config.weight_block_size
        else:
            block_size_m, block_size_n = self.quantization_config.weight_block_size

        rows, cols = param_value.shape[-2:]

        # Check if dimensions are divisible by block sizes
        if rows % block_size_m != 0 or cols % block_size_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
            )

        param_value_orig_shape = param_value.shape

        # Create blocks using unfold
        param_value = param_value.reshape(-1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n).permute(0, 1, 3, 2, 4)
        param_value = param_value.reshape(
            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
        ).permute(0, 1, 3, 2, 4)

        # Calculate scaling factor for each block
        max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
        scale = fp8_max / max_abs
        scale_orig_shape = scale.shape
        # Expand scale to match block dimensions for multiplication
        scale = scale.unsqueeze(-1).unsqueeze(-1)

        # Quantize the weights
        quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
        # Reshape back to matrix shape
        quantized_param = quantized_param.reshape(param_value_orig_shape)

        # Reshape scale to match the number of blocks
        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()

        module._parameters[tensor_name] = quantized_param.to(target_device)
        module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
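To sanity-check what this produces, a hypothetical round-trip (not part of the quantizer) would dequantize the blocks with the stored reciprocal scales and compare against the original weight:

```py
# Illustrative only: dequantize a (rows, cols) FP8 weight quantized per (128, 128) block.
import torch

def dequantize_blockwise(qweight: torch.Tensor, weight_scale_inv: torch.Tensor, block=(128, 128)) -> torch.Tensor:
    rows, cols = qweight.shape
    blocks = qweight.float().reshape(rows // block[0], block[0], cols // block[1], block[1])
    # weight_scale_inv holds one float32 entry per (128, 128) block; broadcast it over each block.
    scales = weight_scale_inv[:, None, :, None]
    return (blocks * scales).reshape(rows, cols)
```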
@ -118,17 +127,17 @@ class FP8HfQuantizer(HfQuantizer):
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from ..integrations.fp8 import FP8Linear, FP8MoELinear
        from ..integrations.fp8 import FP8Linear

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, FP8Linear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
                    raise ValueError("Expect quantized weights but got an unquantized weight")
                return False
            else:
                if tensor_name == "weight_scale":
                if tensor_name == "weight_scale_inv":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                return True
        return False
@ -141,28 +150,27 @@ class FP8HfQuantizer(HfQuantizer):
        **kwargs,
    ):
        from ..integrations.fp8 import replace_with_fp8_linear

        self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert

        if self.quantization_config.modules_to_not_convert:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

        model = replace_with_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
        )

        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        from ..integrations import FP8Linear

@ -181,6 +189,6 @@ class FP8HfQuantizer(HfQuantizer):
    def is_serializable(self, safe_serialization=None):
        return True

    @property
    @property
    def is_trainable(self) -> bool:
        return False  # FP8 quantization is typically used for inference only
        return False
@ -21,7 +21,7 @@ import os
from dataclasses import dataclass
from enum import Enum
from inspect import Parameter, signature
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from packaging import version
@ -1550,13 +1550,14 @@ class BitNetConfig(QuantizationConfigMixin):
    """
    pass


@dataclass
class FP8Config(QuantizationConfigMixin):
    def __init__(
        self,
        modules_to_not_convert: Optional[List] = None,
        activation_scheme: Optional[str] = "dynamic",
        weight_block_size: Optional[Tuple[int, int]] = None,
        weight_block_size: Optional[Tuple[int, int]] = (128, 128),
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FP8

@ -1569,4 +1570,10 @@ class FP8Config(QuantizationConfigMixin):
        r"""
        Safety checker that arguments are correct
        """
        pass
        self.activation_scheme = self.activation_scheme.lower()
        if self.activation_scheme not in ["dynamic"]:
            raise ValueError(f"Activation scheme {self.activation_scheme} not supported")
        if len(self.weight_block_size) != 2:
            raise ValueError("weight_block_size must be a tuple of two integers")
        if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
            raise ValueError("weight_block_size must be a tuple of two positive integers")
0	tests/quantization/fp8_integration/__init__.py	Normal file
226	tests/quantization/fp8_integration/test_fp8.py	Normal file
@ -0,0 +1,226 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FP8Config, OPTForCausalLM
from transformers.testing_utils import (
    require_accelerate,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


@require_torch_gpu
class FP8ConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = FP8Config()
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fp8"}
        quantization_config = FP8Config.from_dict(dict)

        self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
        self.assertEqual(dict["quant_method"], quantization_config.quant_method)


@slow
@require_accelerate
@require_torch_gpu
class FP8QuantizerTest(unittest.TestCase):
    model_name = "meta-llama/Llama-3.2-1B"
    input_text = "Once upon a time"
    max_new_tokens = 10
    EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
    device_map = "cuda"

    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        quantization_config = FP8Config()
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from transformers.integrations import FP8Linear, replace_with_fp8_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = FP8Config()

        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = 0
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                nb_linears += 1

        model = replace_with_fp8_linear(model, quantization_config=quantization_config)
        nb_fp8_linear = 0
        for module in model.modules():
            if isinstance(module, FP8Linear):
                nb_fp8_linear += 1

        self.assertEqual(nb_linears - 1, nb_fp8_linear)

        with init_empty_weights():
            model = OPTForCausalLM(config)
        quantization_config = FP8Config(modules_to_not_convert=["fc1"])
        model = replace_with_fp8_linear(model, quantization_config=quantization_config)
        nb_fp8_linear = 0
        for module in model.modules():
            if isinstance(module, FP8Linear):
                nb_fp8_linear += 1

        self.assertEqual(nb_linears - 25, nb_fp8_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_weight_and_weight_scale_inv(self):
        """
        Simple test that checks if the weight and weight_scale_inv are working properly
        """
        weight = self.quantized_model.model.layers[0].self_attn.q_proj.weight
        weight_scale_inv = self.quantized_model.model.layers[0].self_attn.q_proj.weight_scale_inv
        self.assertEqual(weight.dtype, torch.float8_e4m3fn)
        self.assertEqual(weight_scale_inv.dtype, torch.float32)
        self.assertEqual(weight.shape, (weight_scale_inv.shape[0] * 128, weight_scale_inv.shape[1] * 128))

    def test_block_size(self):
        """
        Simple test that checks if the block size is working properly
        """
        self.assertEqual(self.quantized_model.config.quantization_config.weight_block_size, (128, 128))
        quantization_config = FP8Config(weight_block_size=(32, 32))
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map=self.device_map, quantization_config=quantization_config
        )
        self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
        quantization_config = FP8Config()
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map="auto", quantization_config=quantization_config
        )
        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_save_pretrained_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)

            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
            self.assertTrue(set(model.hf_device_map.values()) == {0, 1})

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)

            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)


@require_torch_gpu
class FP8LinearTest(unittest.TestCase):
    device = "cuda"

    def test_linear_preserves_shape(self):
        """
        Test that FP8Linear preserves shape when in_features == out_features.
        """
        from transformers.integrations import FP8Linear

        linear = FP8Linear(256, 256, device=self.device)
        x = torch.rand((1, 5, 256)).to(self.device)

        x_ = linear(x)
        self.assertEqual(x_.shape, x.shape)

    def test_linear_with_diff_feature_size_preserves_shape(self):
        """
        Test that FP8Linear generates the correct shape when in_features != out_features.
        """
        from transformers.integrations import FP8Linear

        linear = FP8Linear(128, 256, device=self.device)
        x = torch.rand((1, 5, 128)).to(self.device)

        x_ = linear(x)
        self.assertEqual(x_.shape, (1, 5, 256))