diff --git a/src/transformers/integrations/fp8.py b/src/transformers/integrations/fp8.py index 1ca1c843a..ec4c3b188 100644 --- a/src/transformers/integrations/fp8.py +++ b/src/transformers/integrations/fp8.py @@ -227,7 +227,7 @@ def w8a8_block_fp8_matmul_triton( return C -# Python version of the above triton function +# Python version of the above triton function; it's much slower than the triton version @torch.compile def w8a8_block_fp8_matmul_compile( input_q: torch.Tensor, # [batch, seq_len, hidden_dim]