mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add sync
This commit is contained in:
parent
8aa45e177e
commit
9f66405c2e
1 changed file with 15 additions and 11 deletions
|
|
@@ -398,17 +398,21 @@ def linear(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bi
|
|||
print("value not the one expected")
|
||||
return F.linear(x, weight, bias)
|
||||
else:
|
||||
qinput, scale = per_token_group_quant_fp8(x, block_size[1])
|
||||
|
||||
output = w8a8_block_fp8_matmul(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=x.dtype,
|
||||
)
|
||||
|
||||
with torch.cuda.device(x.device):
|
||||
qinput, scale = per_token_group_quant_fp8(x, block_size[1])
|
||||
torch.cuda.synchronize(device=x.device)
|
||||
with torch.cuda.device(x.device):
|
||||
output = w8a8_block_fp8_matmul(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=x.dtype,
|
||||
)
|
||||
torch.cuda.synchronize(device=x.device)
|
||||
# with open("output_log.txt", "a") as f:
|
||||
# f.write(f"nans after: {torch.sum(torch.isnan(output))} on {output.device} with scale {scale if torch.sum(torch.isnan(output)) > 0 else 'none'}\n")
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=x.dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue