This commit is contained in:
MekkCyber 2025-02-07 10:41:53 +00:00
parent 99f9e044d7
commit 99f9afb079
2 changed files with 31 additions and 13 deletions

View file

@ -230,7 +230,7 @@ def _w8a8_block_fp8_matmul(
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
def w8a8_block_fp8_matmul_(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@ -391,20 +391,38 @@ def w8a8_block_fp8_matmul_compile(
return output.to(output_dtype)
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
def linear(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    activation_scheme: str = "dynamic",
) -> torch.Tensor:
    """Apply a linear layer, using the block-wise FP8 GEMM for 1-byte weights.

    Weights whose elements are wider than one byte are treated as ordinary
    floating-point weights and dispatched to ``F.linear``. One-byte weights
    are assumed to be block-quantized FP8 (NOTE(review): only the element
    size is checked, not the dtype): the activation is quantized per token
    group on the fly and multiplied with ``w8a8_block_fp8_matmul``.

    Args:
        x: Input activation.
        weight: Weight matrix; its element size selects the code path.
        weight_scale: Scales for ``weight``. Only read on the FP8 path.
        bias: Optional bias added to the result.
        block_size: Quantization block shape. ``block_size[1]`` is used as
            the activation group size and the whole tuple is forwarded to
            the GEMM. Required on the FP8 path.
        activation_scheme: Accepted for interface compatibility; this
            function does not consult it and always quantizes dynamically.

    Returns:
        The layer output. On the FP8 path it is cast to ``x.dtype``.

    Raises:
        ValueError: If ``weight`` is a 1-byte tensor and ``block_size`` is
            ``None``.
    """
    if weight.element_size() > 1:
        # Not a quantized weight: fall back to the regular dense kernel.
        return F.linear(x, weight, bias)

    if block_size is None:
        # Fail with a clear message instead of "'NoneType' object is not
        # subscriptable" from the indexing below.
        raise ValueError(
            "block_size must be provided when `weight` is FP8-quantized"
        )

    qinput, scale = per_token_group_quant_fp8(x, block_size[1])
    output = w8a8_block_fp8_matmul(
        qinput,
        weight,
        scale,
        weight_scale,
        block_size,
        output_dtype=x.dtype,
    )
    if bias is not None:
        # Out-of-place add: avoids mutating the GEMM output buffer in place.
        output = output + bias
    # The bias add may promote the dtype, so cast back to the input dtype.
    return output.to(dtype=x.dtype)
class FP8Linear(nn.Linear):

View file

@ -36,7 +36,7 @@ _CONFIG_FOR_DOC = "DeepseekV3Config"
class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, hidden_size, eps=1e-4):
"""
DeepseekV3RMSNorm is equivalent to T5LayerNorm
"""
@ -731,8 +731,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
nan_count = torch.sum(torch.isnan(layer_outputs[0])).item()
print("nan_count", nan_count, "layer", idx)
# nan_count = torch.sum(torch.isnan(layer_outputs[0])).item()
# print("nan_count", nan_count, "layer", idx)
hidden_states = layer_outputs[0]
if output_attentions: