mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[CUTLASS][FP8] Skip scaled_mm rowwise test on sm89 (#133612)
The rowwise implementation currently uses sm90-specific features, incl. TMA. CC @drisspg. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133612. Approved by: https://github.com/Skylion007
This commit is contained in:
parent
413416cf33
commit
7ad3108ef2
1 changed file with 2 additions and 0 deletions
|
|
@ -16,6 +16,7 @@ from torch.quantization._quantized_conversions import (
|
|||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_cuda import (
|
||||
SM53OrLater,
|
||||
SM90OrLater,
|
||||
_get_torch_cuda_version,
|
||||
PLATFORM_SUPPORTS_FP8
|
||||
)
|
||||
|
|
@ -664,6 +665,7 @@ class TestFP8MatmulCuda(TestCase):
|
|||
)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
|
||||
@skipIfRocm()
|
||||
@parametrize("base_dtype", [torch.bfloat16])
|
||||
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
|
||||
|
|
|
|||
Loading…
Reference in a new issue