This commit is contained in:
MekkCyber 2025-02-07 21:04:53 +00:00
parent 8aa45e177e
commit 9f66405c2e

View file

@ -398,17 +398,21 @@ def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bi
print("value not the one expected")
return F.linear(x, weight, bias)
else:
qinput, scale = per_token_group_quant_fp8(x, block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
weight,
scale,
weight_scale,
block_size,
output_dtype=x.dtype,
)
with torch.cuda.device(x.device):
qinput, scale = per_token_group_quant_fp8(x, block_size[1])
torch.cuda.synchronize(device=x.device)
with torch.cuda.device(x.device):
output = w8a8_block_fp8_matmul(
qinput,
weight,
scale,
weight_scale,
block_size,
output_dtype=x.dtype,
)
torch.cuda.synchronize(device=x.device)
# with open("output_log.txt", "a") as f:
# f.write(f"nans after: {torch.sum(torch.isnan(output))} on {output.device} with scale {scale if torch.sum(torch.isnan(output)) > 0 else 'none'}\n")
if bias is not None:
output = output + bias
return output.to(dtype=x.dtype)