diff --git a/src/transformers/integrations/fp8.py b/src/transformers/integrations/fp8.py index 0acc9930a..8a89e59e8 100644 --- a/src/transformers/integrations/fp8.py +++ b/src/transformers/integrations/fp8.py @@ -230,7 +230,7 @@ def _w8a8_block_fp8_matmul( c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) -def w8a8_block_fp8_matmul( +def w8a8_block_fp8_matmul_( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -391,20 +391,38 @@ def w8a8_block_fp8_matmul_compile( return output.to(output_dtype) +from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 + def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, activation_scheme: str = "dynamic") -> torch.Tensor: if weight.element_size() > 1: + print("value not the one expected") return F.linear(x, weight, bias) else: - x, scale = act_quant(x, block_size[0]) - # if x.shape[1] > 1: - # print("x", x.shape, x.dtype) - # print("scale", scale.shape) - # y = fp8_gemm(x, scale, weight, weight_scale) - y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size) - # y = w8a8_block_fp8_matmul_compile(x, weight, scale, weight_scale, block_size) + qinput, scale = per_token_group_quant_fp8(x, block_size[1]) + + output = w8a8_block_fp8_matmul( + qinput, + weight, + scale, + weight_scale, + block_size, + output_dtype=x.dtype, + ) + if bias is not None: - y += bias - return y + output = output + bias + return output.to(dtype=x.dtype) + + # x, scale = act_quant(x, block_size[0]) + # # if x.shape[1] > 1: + # # print("x", x.shape, x.dtype) + # # print("scale", scale.shape) + # # y = fp8_gemm(x, scale, weight, weight_scale) + # y = w8a8_block_fp8_matmul(x, weight, scale, weight_scale, block_size) + # # y = w8a8_block_fp8_matmul_compile(x, weight, scale, weight_scale, block_size) + # if bias is not None: + # y += bias + # return y class FP8Linear(nn.Linear): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 1dedc0d9f..c16e8d071 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -36,7 +36,7 @@ _CONFIG_FOR_DOC = "DeepseekV3Config" class DeepseekV3RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-4): """ DeepseekV3RMSNorm is equivalent to T5LayerNorm """ @@ -731,8 +731,8 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): position_embeddings=position_embeddings, **flash_attn_kwargs, ) - nan_count = torch.sum(torch.isnan(layer_outputs[0])).item() - print("nan_count", nan_count, "layer", idx) + # nan_count = torch.sum(torch.isnan(layer_outputs[0])).item() + # print("nan_count", nan_count, "layer", idx) hidden_states = layer_outputs[0] if output_attentions: