fix create_quantized_param

MekkCyber 2025-02-04 17:09:58 +00:00
parent 3700bbc09f
commit 70749dfd9b
2 changed files with 199 additions and 33 deletions


@@ -56,11 +56,11 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous()
assert x.shape[-1] % block_size == 0
assert x.shape[-1] % block_size[0] == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size[0], dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size[0])
return y, s
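
A minimal usage sketch, assuming block_size is now passed as a two-element list such as [128, 128] rather than a plain int (illustrative shapes, CUDA device required for the Triton kernel):

x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
y, s = act_quant(x, block_size=[128, 128])
# y: torch.float8_e4m3fn, shape (4, 256) -- quantized activations
# s: torch.float32, shape (4, 2) -- one scale per 128-element group along the last dimension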
@@ -131,6 +131,177 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
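# Load the per-token-group scale of A and the per-block scale of B for this k-group,
# then rescale the fp8 partial product before accumulating in float32.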
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# TODO:
# BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
# BLOCK_SIZE_K must evenly divide block_k (below it is simply set equal to block_k)
# BLOCK_SIZE_N and BLOCK_SIZE_M have no such requirement
BLOCK_SIZE_M = 128
if M < BLOCK_SIZE_M:
BLOCK_SIZE_M = triton.next_power_of_2(M)
BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
BLOCK_SIZE_K = block_k
assert block_k % BLOCK_SIZE_K == 0
BLOCK_SIZE_N = block_n
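# BLOCK_SIZE_K and BLOCK_SIZE_N match the quantization group sizes, so each inner-loop
# iteration of the kernel covers exactly one scale group along K.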
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=8,
)
return C
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
assert a.is_contiguous() and b.is_contiguous()
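
As a rough orientation, a minimal shape sketch for w8a8_block_fp8_matmul assuming block_size = [128, 128]; the dimensions and unit scales below are illustrative only (real scales come from act_quant and the quantized checkpoint):

M, N, K = 4, 512, 256
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)              # activation
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)              # weight
As = torch.ones(M, K // 128, dtype=torch.float32, device="cuda")          # per-token-group scales for A
Bs = torch.ones(N // 128, K // 128, dtype=torch.float32, device="cuda")   # per-block scales for B
C = w8a8_block_fp8_matmul(A, B, As, Bs, block_size=[128, 128], output_dtype=torch.bfloat16)
# C: torch.bfloat16, shape (M, N) == (4, 512)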
@@ -147,49 +318,45 @@ def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bi
if weight.element_size() > 1:
return F.linear(x, weight, bias)
else:
if block_size is None:
block_size = 128
else :
block_size = 128
x, scale = act_quant(x, block_size)
y = fp8_gemm(x, scale, weight, weight_scale)
# y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size)
if bias is not None:
y += bias
return y
class FP8Linear(nn.Module):
class FP8Linear(nn.Linear):
dtype = torch.float8_e4m3fn
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None, block_size: Optional[Tuple[int, int]] = None, device=None, activation_scheme="dynamic"):
super().__init__()
super().__init__(in_features=in_features, out_features=out_features)
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
# if self.weight.element_size() == 1:
if block_size is None:
block_size = self.weight.shape
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device))
# else:
# self.register_parameter("weight_scale", None)
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device))
else:
self.register_parameter("weight_scale_inv", None)
self.block_size = block_size
if activation_scheme == "dynamic":
self.register_parameter("input_scale", None)
else :
if activation_scheme != "dynamic":
raise ValueError(f"Only dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}")
self.activation_scheme = activation_scheme
if bias:
self.bias = nn.Parameter(torch.empty(self.part_out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
print(self.weight_scale)
return linear(x, self.weight, self.weight_scale, self.bias, self.block_size, self.activation_scheme)
return linear(x, self.weight, self.weight_scale_inv, self.bias, self.block_size, self.activation_scheme)
class FP8MoELinear(FP8Linear):
"""FP8 Linear layer for MoE implementation."""


@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter as Parameter
from .base import HfQuantizer
from ..utils import is_accelerate_available, logging
@@ -16,14 +17,13 @@ class FP8HfQuantizer(HfQuantizer):
Supports both e4m3fn and e4m3fnuz formats based on platform.
"""
requires_parameters_quantization = False
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
self.is_moe_model = kwargs.get("is_moe_model", False)
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
@@ -69,11 +69,10 @@ class FP8HfQuantizer(HfQuantizer):
- Per-tensor quantization when weight_block_size is None
"""
module, tensor_name = get_module_from_name(model, param_name)
print("######################### in quantizer_fp8.py #########################")
# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
if self.quantization_config.weight_block_size is not None:
block_size_m, block_size_n = self.quantization_config.weight_block_size
@@ -105,22 +104,22 @@ class FP8HfQuantizer(HfQuantizer):
quantized_param = quantized_param.reshape(param_value.shape[:-4] + (rows, cols))
# Reshape scale to match the number of blocks
scale = scale.reshape(scale.shape[:-2] + (-1,)).reciprocal()
scale = scale.reshape(scale.shape[:-2] + (-1,)).squeeze(-1).reciprocal()
else:
# Per-tensor quantization
max_abs = torch.max(torch.abs(param_value))
print("###################max_abs#################", max_abs)
# print("###################max_abs#################", max_abs)
scale = fp8_max / max_abs
print("###################scale#################", scale)
# print("###################scale#################", scale)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# For per-tensor quantization, we just need a single scale value
scale = torch.tensor([[scale]]).reciprocal()
# print("###################after reciprocal scale#################", scale)
# Store the quantized weights and scales in the module
print(f"tensor_name in create_quantized_param: {tensor_name} {target_device}")
module._buffers[tensor_name] = quantized_param.to(target_device)
module._buffers["weight_scale"] = scale.to(target_device)
module._parameters[tensor_name] = quantized_param.to(target_device)
module.register_parameter("weight_scale_inv", nn.Parameter(scale.to(target_device)))
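# Illustrative sketch (hypothetical values, not taken from a real checkpoint) of the per-tensor
# path above; the block-wise path applies the same formula independently per
# (block_size_m, block_size_n) weight block:
fp8_min = torch.finfo(torch.float8_e4m3fn).min   # -448.0
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0
param_value = torch.randn(1024, 512)
scale = fp8_max / torch.max(torch.abs(param_value))
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
weight_scale_inv = torch.tensor([[scale]]).reciprocal()   # inverse is stored so dequantization is a plain multiply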
def check_quantized_param(
self,
@@ -136,7 +135,7 @@ class FP8HfQuantizer(HfQuantizer):
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
@@ -149,12 +148,12 @@ class FP8HfQuantizer(HfQuantizer):
self,
model: "PreTrainedModel",
device_map,
keep_in_fp32_modules: List[str] = [],
modules_to_not_convert: List[str] = [],
**kwargs,
):
from ..integrations.fp8 import replace_with_fp8_linear
self.modules_to_not_convert = ["lm_head"] + keep_in_fp32_modules
self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
if self.quantization_config.modules_to_not_convert:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)