mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00

commit 99f9afb079
parent 99f9e044d7

2 changed files with 31 additions and 13 deletions
@@ -230,7 +230,7 @@ def _w8a8_block_fp8_matmul(
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask=c_mask)
 
-def w8a8_block_fp8_matmul(
+def w8a8_block_fp8_matmul_(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
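The trailing-underscore rename keeps the local Triton wrapper available while freeing the name w8a8_block_fp8_matmul for the moe_kernels import in the next hunk. As a rough reference for what this kind of block-wise W8A8 FP8 GEMM computes, here is a minimal dequantize-then-matmul sketch in plain PyTorch; the A/B/As/Bs names mirror the signature above, but the scale layouts (per token-group for As, per (block_n, block_k) tile for Bs) are assumptions based on common implementations of this kernel, not taken from this diff:

import torch

def w8a8_block_fp8_matmul_ref(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
    # Reference (slow) semantics sketch -- assumed layouts:
    #   A: (M, K) fp8 activations;  As: (M, ceil(K/block_k)) per-group scales
    #   B: (N, K) fp8 weights;      Bs: (ceil(N/block_n), ceil(K/block_k)) tile scales
    block_n, block_k = block_size
    M, K = A.shape
    N, _ = B.shape
    # Dequantize A: broadcast each per-group scale across its K-group.
    a = A.to(torch.float32) * As.repeat_interleave(block_k, dim=1)[:, :K]
    # Dequantize B: broadcast each tile scale across its (block_n, block_k) tile.
    bs = Bs.repeat_interleave(block_n, dim=0)[:N, :]
    bs = bs.repeat_interleave(block_k, dim=1)[:, :K]
    b = B.to(torch.float32) * bs
    # The real Triton kernel fuses this scaling into the tiled matmul instead.
    return (a @ b.t()).to(output_dtype)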
@@ -391,20 +391,38 @@ def w8a8_block_fp8_matmul_compile(
 
     return output.to(output_dtype)
 
 
 from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 
 
 def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, activation_scheme: str = "dynamic") -> torch.Tensor:
     if weight.element_size() > 1:
         print("value not the one expected")
         return F.linear(x, weight, bias)
     else:
-        x, scale = act_quant(x, block_size[0])
-        # if x.shape[1] > 1:
-        #     print("x", x.shape, x.dtype)
-        #     print("scale", scale.shape)
-        # y = fp8_gemm(x, scale, weight, weight_scale)
-        y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size)
-        # y = w8a8_block_fp8_matmul_compile(x, weight, scale, weight_scale, block_size)
+        qinput, scale = per_token_group_quant_fp8(x, block_size[1])
+
+        output = w8a8_block_fp8_matmul(
+            qinput,
+            weight,
+            scale,
+            weight_scale,
+            block_size,
+            output_dtype=x.dtype,
+        )
+
         if bias is not None:
-            y += bias
-        return y
+            output = output + bias
+        return output.to(dtype=x.dtype)
+
+        # x, scale = act_quant(x, block_size[0])
+        # # if x.shape[1] > 1:
+        # #     print("x", x.shape, x.dtype)
+        # #     print("scale", scale.shape)
+        # # y = fp8_gemm(x, scale, weight, weight_scale)
+        # y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size)
+        # # y = w8a8_block_fp8_matmul_compile(x, weight, scale, weight_scale, block_size)
+        # if bias is not None:
+        #     y += bias
+        # return y
 
 
 class FP8Linear(nn.Linear):
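The replacement path quantizes the activations on the fly with per_token_group_quant_fp8, passing block_size[1] (presumably the K-block) as the group size so activation groups line up with the weight's K-tiling. A minimal sketch of what such a routine typically does, assuming max-abs scaling to float8_e4m3fn over contiguous groups along the last dimension (the actual moe_kernels implementation is not shown in this diff):

import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int):
    # Assumes x.shape[-1] is divisible by group_size.
    finfo = torch.finfo(torch.float8_e4m3fn)
    orig_shape = x.shape
    x = x.reshape(-1, group_size).to(torch.float32)
    # One scale per group: map the group's max |value| onto the fp8 max (448 for e4m3fn).
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)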
@@ -36,7 +36,7 @@ _CONFIG_FOR_DOC = "DeepseekV3Config"
 
 
 class DeepseekV3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-4):
         """
         DeepseekV3RMSNorm is equivalent to T5LayerNorm
         """
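The only change here is the default epsilon, loosened from 1e-6 to 1e-4. In RMSNorm, eps sits under the rsqrt, so a larger value keeps the normalizer finite when the mean square underflows at low precision; that concern fits the FP8 matmul path and the NaN counting below, though the commit itself states no rationale. For context, the standard transformers RMSNorm forward that this class implements:

import torch
from torch import nn

class DeepseekV3RMSNormSketch(nn.Module):
    def __init__(self, hidden_size, eps=1e-4):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # eps keeps rsqrt finite when the mean square underflows to zero.
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)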
@@ -731,8 +731,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
                 position_embeddings=position_embeddings,
                 **flash_attn_kwargs,
             )
-            nan_count = torch.sum(torch.isnan(layer_outputs[0])).item()
-            print("nan_count", nan_count, "layer", idx)
+            # nan_count = torch.sum(torch.isnan(layer_outputs[0])).item()
+            # print("nan_count", nan_count, "layer", idx)
             hidden_states = layer_outputs[0]
 
             if output_attentions:
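The per-layer NaN counter is commented out rather than deleted. If the probe is needed again, a forward hook keeps it out of the modeling code; a sketch of a hypothetical helper (attach_nan_probe is not part of this commit, and it assumes the model exposes its decoder stack as model.layers):

import torch

def attach_nan_probe(model):
    # Register a forward hook per decoder layer that reports NaNs in its output.
    handles = []
    for idx, layer in enumerate(model.layers):
        def hook(module, args, output, idx=idx):
            hidden = output[0] if isinstance(output, tuple) else output
            nan_count = torch.isnan(hidden).sum().item()
            if nan_count:
                print("nan_count", nan_count, "layer", idx)
        handles.append(layer.register_forward_hook(hook))
    return handles  # call h.remove() on each handle to detach the probe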