adding kernels

MekkCyber 2025-02-04 14:57:18 +00:00
parent b0c3641f56
commit 3700bbc09f
2 changed files with 181 additions and 310 deletions


@@ -17,290 +17,179 @@ import torch
import torch.nn as nn
from typing import Optional, List, Tuple, Union
from ..utils import is_accelerate_available, logging
from torch.nn import functional as F
import triton
import triton.language as tl
from triton import Config
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
ACTIVATION_SCHEMES = ["static", "dynamic"]
ACTIVATION_SCHEMES = ["dynamic"]
quant_dtype = torch.float8_e4m3fn
def fp8_quantize(weight: torch.Tensor, scale: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize weights to FP8."""
    if scale is None:
        # Calculate the scale as the FP8 max value (448) divided by absmax
        scale = 448.0 / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp the tensor to the FP8 range
        qweight = (weight * scale).clamp(min=-448.0, max=448.0)
        scale = scale.float().reciprocal()
    else:
        qweight = (weight * scale.reciprocal()).clamp(min=-448.0, max=448.0)
    qweight = qweight.to(quant_dtype)
    return qweight, scale
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous()
assert x.shape[-1] % block_size == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
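A minimal usage sketch for act_quant (illustrative only, not part of this commit; it assumes a CUDA device since the Triton kernel runs on GPU):
x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_fp8, x_scales = act_quant(x, block_size=128)
# x_fp8 keeps the shape of x but has dtype torch.float8_e4m3fn;
# x_scales is float32 with shape [4, 2]: one scale per 128-element group of the last dimension.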
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
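A hedged sketch of calling weight_dequant on a block-quantized weight (illustrative values; the scale grid carries one float32 entry per 128x128 tile):
w_fp8 = torch.randn(512, 1024, device="cuda").to(torch.float8_e4m3fn)   # stand-in for a quantized weight
w_scales = torch.rand(4, 8, device="cuda", dtype=torch.float32)         # [512/128, 1024/128] tile scales
w = weight_dequant(w_fp8, w_scales, block_size=128)                     # [512, 1024] in the default dtype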
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
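Putting the kernels together, a hedged end-to-end sketch of the dynamic path that the linear() helper below wires up (assumes an FP8-capable GPU such as H100, since tl.dot on float8 inputs needs hardware support; shapes are illustrative):
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
w_fp8 = torch.randn(4096, 1024, device="cuda").to(torch.float8_e4m3fn)   # [out_features, in_features]
w_scales = torch.rand(4096 // 128, 1024 // 128, device="cuda")           # one scale per 128x128 weight block
x_fp8, x_scales = act_quant(x, block_size=128)                           # per-group activation quantization
y = fp8_gemm(x_fp8, x_scales, w_fp8, w_scales)                           # [8, 4096] output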
def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, activation_scheme: str = "dynamic") -> torch.Tensor:
    if weight.element_size() > 1:
        # Weight is not FP8 (element size > 1 byte): fall back to a regular linear
        return F.linear(x, weight, bias)
    else:
        # Activations are quantized dynamically with a fixed group size of 128
        block_size = 128
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight_scale)
        if bias is not None:
            y += bias
        return y
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-12,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Performs per-token-group quantization on input tensor, converting to FP8.
Args:
x (torch.Tensor): Input tensor to quantize (shape: [..., hidden_dim])
group_size (int): Size of groups for quantization
eps (float): Small value to avoid division by zero
dtype (torch.dtype, optional): FP8 dtype to use. Defaults to platform-specific.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors
"""
# Input validation
assert x.ndim >= 2, "Input tensor must have at least 2 dimensions"
assert x.shape[-1] % group_size == 0, f"Last dimension ({x.shape[-1]}) must be divisible by group_size ({group_size})"
# Determine FP8 dtype and limits
if dtype is None:
dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
finfo = torch.finfo(dtype)
# Reshape input for group-wise operations
orig_shape = x.shape
num_groups = x.shape[-1] // group_size
# Reshape to [*, num_groups, group_size]
x_reshaped = x.view(-1, num_groups, group_size)
# Calculate max absolute values per group
max_abs = x_reshaped.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps)
# Calculate scales as max_dtype / max_abs
scales = finfo.max / max_abs
# Quantize values
x_scaled = (x_reshaped * scales)
x_quant = x_scaled.clamp(min=finfo.min, max=finfo.max).to(dtype)
# Reshape back to original shape
x_quant = x_quant.view(orig_shape)
# Process scales
scales = scales.squeeze(-1) # Remove the last singleton dimension
scales = scales.view(-1, num_groups)
# Return reciprocal of scales for compatibility with other operations
return x_quant, scales.float().reciprocal()
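A brief, hedged example of the shapes involved (illustrative; this path is plain PyTorch, so it also runs on CPU):
x = torch.randn(2, 4, 256)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q: [2, 4, 256] in FP8; x_s: [8, 2] float32 (2*4 tokens, 256/128 groups per token),
# stored as reciprocals so that dequantization is a plain multiplication.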
def per_token_group_dequant_fp8(
x: torch.Tensor,
scales: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16
) -> torch.Tensor:
"""
Dequantizes FP8 tensor back to floating point using group scales.
Args:
x (torch.Tensor): Quantized input tensor
scales (torch.Tensor): Scale factors (reciprocal)
group_size (int): Size of groups used in quantization
output_dtype (torch.dtype): Output dtype
Returns:
torch.Tensor: Dequantized tensor
"""
# Reshape input for group-wise operations
orig_shape = x.shape
num_groups = x.shape[-1] // group_size
x_reshaped = x.view(-1, num_groups, group_size)
# Ensure scales have correct shape
if scales.ndim == 2:
scales = scales.view(-1, num_groups, 1)
else:
scales = scales.view(*orig_shape[:-1], num_groups, 1)
# Dequantize
x_dequant = x_reshaped.to(torch.float32) * scales
# Reshape back and convert to desired dtype
return x_dequant.view(orig_shape).to(output_dtype)
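And the corresponding round trip, which should recover the input up to FP8 rounding error (hedged example, not part of the commit):
x = torch.randn(2, 4, 256)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
x_rec = per_token_group_dequant_fp8(x_q, x_s, group_size=128, output_dtype=torch.float32)
print((x - x_rec).abs().max())   # small, on the order of the FP8 quantization step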
@torch.compile
def w8a8_block_fp8_matmul(
input_q: torch.Tensor, # [batch, seq_len, hidden_dim]
weight_q: torch.Tensor, # [out_features, hidden_dim]
input_scale: torch.Tensor, # [batch * seq_len, num_input_groups]
weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n]
block_size: Tuple[int, int], # (M=128, N=128) for weights
output_dtype: torch.dtype = torch.float16
) -> torch.Tensor:
"""
Performs blocked matrix multiplication with FP8 quantized matrices.
Args:
input_q: Quantized input tensor with 1x128 block quantization
weight_q: Quantized weight tensor with 128x128 block quantization
input_scale: Scaling factors for input blocks
weight_scale: Scaling factors for weight blocks
block_size: Tuple of (M, N) for weight block dimensions
output_dtype: Desired output dtype
"""
batch_size, seq_len, hidden_dim = input_q.shape
out_features = weight_q.shape[0]
# Reshape input for batched matmul
input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim]
# Calculate number of blocks
num_weight_blocks_m = out_features // block_size[0]
num_weight_blocks_n = hidden_dim // block_size[1]
# Initialize output tensor
output = torch.zeros((batch_size * seq_len, out_features),
dtype=torch.float32,
device=input_q.device)
# Process each block
for i in range(num_weight_blocks_m):
m_start = i * block_size[0]
m_end = m_start + block_size[0]
for j in range(num_weight_blocks_n):
n_start = j * block_size[1]
n_end = n_start + block_size[1]
# Extract current blocks
input_block = input_reshaped[:, n_start:n_end]
weight_block = weight_q[m_start:m_end, n_start:n_end]
# Get corresponding scales
curr_input_scale = input_scale[:, j:j+1] # [batch*seq_len, 1]
curr_weight_scale = weight_scale[i, j] # scalar
# Dequantize and multiply
block_result = torch._scaled_mm(
input_block,
weight_block.t(),
scale_a=curr_input_scale,
scale_b=curr_weight_scale,
out_dtype=torch.float32,  # accumulate in float32 to match the output buffer
)
# block_result = torch.matmul(
# input_block.to(torch.float32) * curr_input_scale,
# weight_block.to(torch.float32).t() * curr_weight_scale
# )
# Accumulate result
output[:, m_start:m_end] += block_result
# Reshape output back to original dimensions
output = output.view(batch_size, seq_len, out_features)
return output.to(output_dtype)
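For intuition and testing, the blocked product above is numerically equivalent to dequantizing both operands and doing one plain matmul. A hedged pure-PyTorch sketch of that equivalence (not part of this commit; unit weight-block scales are assumed purely to keep the sketch short):
bsz, seq, hidden, out_f = 2, 4, 256, 512
x = torch.randn(bsz, seq, hidden)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)                    # 1x128 input groups
w_q = torch.randn(out_f, hidden).clamp(-448, 448).to(torch.float8_e4m3fn)  # stand-in for a quantized weight
w_s = torch.ones(out_f // 128, hidden // 128)                              # assumed 128x128 block scales
# Dequantize by broadcasting the (reciprocal) scales back over their groups/blocks, then matmul
x_deq = (x_q.to(torch.float32).view(-1, hidden // 128, 128) * x_s.unsqueeze(-1)).view(-1, hidden)
w_deq = (w_q.to(torch.float32).view(out_f // 128, 128, hidden // 128, 128) * w_s[:, None, :, None]).reshape(out_f, hidden)
y_ref = (x_deq @ w_deq.t()).view(bsz, seq, out_f)                          # reference output to compare against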
def fp8_quantize(weight, scale: Optional[torch.Tensor] = None, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp the tensor to bring it into the representable range
        # of the float8 data type (the default cast is unsaturated)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    return qweight, scale
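A short sketch of the per-tensor path this helper serves; torch._scaled_mm takes the two FP8 tensors plus their inverse scales (hedged: assumes a recent PyTorch and an FP8-capable GPU):
w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
w_q, w_s = fp8_quantize(w)   # w_s is the inverse scale, as torch._scaled_mm expects
x_q, x_s = fp8_quantize(x)
y = torch._scaled_mm(x_q, w_q.t(), scale_a=x_s, scale_b=w_s, out_dtype=torch.bfloat16)   # [8, 512]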
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Convert e4m3fn weights and scales to e4m3fnuz format for ROCm compatibility."""
if weight.dtype != torch.float8_e4m3fn:
return weight, weight_scale, input_scale
# Convert -128 (NaN in e4m3fnuz) to 0
weight_as_int8 = weight.view(torch.int8)
weight_as_int8[weight_as_int8 == -128] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
# Double scales since e4m3fnuz values are half of e4m3fn
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
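A small illustration of the ROCm normalization (hedged; assumes a PyTorch build that exposes float8_e4m3fnuz): the payload bits are reinterpreted while the scales are doubled, so the represented values are unchanged.
w_q, w_s = fp8_quantize(torch.randn(64, 64))
w_q2, w_s2, _ = normalize_e4m3fn_to_e4m3fnuz(w_q, w_s)
# w_q2 is float8_e4m3fnuz with the same bit pattern (the -128 NaN encoding zeroed out)
# and w_s2 == 2 * w_s, so w_q2 * w_s2 dequantizes to the same values as w_q * w_s.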
class FP8Linear(nn.Module):
    """FP8 Linear layer implementation."""

    dtype = torch.float8_e4m3fn

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None, block_size: Optional[Tuple[int, int]] = None, device=None, activation_scheme="dynamic"):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
        if block_size is None:
            block_size = self.weight.shape
        scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
        scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
        self.weight_scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device))
        self.block_size = block_size
        if activation_scheme == "dynamic":
            self.register_parameter("input_scale", None)
        else:
            raise ValueError(f"Only the dynamic activation scheme is supported for FP8Linear for now, you provided {activation_scheme}")
        self.activation_scheme = activation_scheme
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
        else:
            self.register_parameter("bias", None)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.weight_scale, self.bias, self.block_size, self.activation_scheme)
class FP8MoELinear(FP8Linear):
"""FP8 Linear layer for MoE implementation."""
@@ -339,17 +228,14 @@ class FP8MoELinear(FP8Linear):
)
    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
        if self.activation_scheme == "dynamic":
            input_scale = x.abs().max() / torch.finfo(quant_dtype).max
        # Select expert weights and scales
        selected_weights = self.weight[expert_indices]
        selected_scales = self.weight_scale[expert_indices]
        # Perform FP8 matmul for each expert
        output = torch._scaled_mm(
@@ -384,32 +270,15 @@ def _replace_with_fp8_linear(
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
# Check if this is an MoE layer
is_moe = any(moe_key in current_key_name_str
for moe_key in ["gate", "experts"])
is_moe = False
if is_moe:
n_experts = getattr(model.config, "num_experts", 8)
model._modules[name] = FP8MoELinear(
n_experts=n_experts,
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
weight_block_size=quantization_config.weight_block_size
)
else:
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
weight_block_size=quantization_config.weight_block_size
)
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size
)
has_been_replaced = True
if len(list(module.children())) > 0:


@@ -69,16 +69,15 @@ class FP8HfQuantizer(HfQuantizer):
- Per-tensor quantization when weight_block_size is None
"""
module, tensor_name = get_module_from_name(model, param_name)
print("######################### in quantizer_fp8.py #########################")
# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
if self.quantization_config.weight_block_size is not None:
# Block-wise quantization
block_size_m, block_size_n = self.quantization_config.weight_block_size
# Get matrix dimensions
rows, cols = param_value.shape[-2:]
# Check if dimensions are divisible by block sizes
@@ -87,6 +86,7 @@ class FP8HfQuantizer(HfQuantizer):
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
)
# Create blocks using unfold
param_value = param_value.unfold(-2, block_size_m, block_size_m)
param_value = param_value.unfold(-2, block_size_n, block_size_n)
@@ -99,24 +99,24 @@ class FP8HfQuantizer(HfQuantizer):
scale = scale.unsqueeze(-1).unsqueeze(-1)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# Reshape back to matrix shape
quantized_param = quantized_param.permute(0, 2, 1, 3).reshape(rows, cols)
# Reshape scale to match the number of blocks
scale = scale.reshape(scale.shape[:-2] + (-1,)).reciprocal()
else:
# Per-tensor quantization
max_abs = torch.max(torch.abs(param_value))
print("###################max_abs#################", max_abs)
scale = fp8_max / max_abs
print("###################scale#################", scale)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# For per-tensor quantization, we just need a single scale value
scale = torch.tensor([[scale]]).reciprocal()
# Store the quantized weights and scales in the module
print(f"tensor_name in create_quantized_param: {tensor_name} {target_device}")
module._buffers[tensor_name] = quantized_param.to(target_device)
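For intuition, a hedged standalone sketch of the unfold-based 128x128 block quantization used above (illustrative sizes, not part of the commit):
fp8_max = torch.finfo(torch.float8_e4m3fn).max
w = torch.randn(256, 384)
blocks = w.unfold(-2, 128, 128).unfold(-2, 128, 128)                 # [2, 3, 128, 128] blocks
scale = fp8_max / blocks.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=1e-12)
w_q = torch.clamp(blocks * scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
w_q = w_q.permute(0, 2, 1, 3).reshape(256, 384)                      # back to the original [rows, cols] layout
inv_scale = scale.reshape(2, 3).reciprocal()                         # one inverse scale per 128x128 block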
@@ -134,13 +134,15 @@ class FP8HfQuantizer(HfQuantizer):
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
    if self.pre_quantized or tensor_name == "bias":
        if tensor_name == "weight" and param_value.dtype != torch.int8:
            raise ValueError("Expect quantized weights but got an unquantized weight")
        return False
    else:
        if tensor_name == "weight_scale":
            raise ValueError("Expect unquantized weights but got a quantized weight_scale")
        return True
return False
def _process_model_before_weight_loading(