From ffa628169dd732b4d95db29bd7728bb145be793c Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Thu, 30 Jan 2025 11:19:56 +0000 Subject: [PATCH] [ATen][Native][CUDA][SCALED_MM] limit f8f8bf16 rowwise scaled matmul to sm_90 (#145728) The CUTLASS-based kernel for f8f8bf16 rowwise scaled matmul is specific to Hopper devices only. It is not reusable on newer devices without modifications. This PR adds a guard so that this matmul is restricted to sm_90 (and sm_89) devices. Once a kernel supporting newer architectures is available, the guard may be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145728 Approved by: https://github.com/Skylion007, https://github.com/eqy --- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index d2494bd43d5..9bfc74b4ed2 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -708,13 +708,13 @@ void dispatch_fp8_rowwise_kernel_on_sm( at::Tensor out) { cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; - const bool sm90OrLater = properties != nullptr && properties->major >= 9; - if (!(sm89 || sm90OrLater)) { + const bool sm9x = properties != nullptr && properties->major == 9; + if (!(sm89 || sm9x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } - if (sm90OrLater) { + if (sm9x) { dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(XQ, WQ, x_scale, w_scale, bias, out); } else { f8f8bf16_rowwise_impl_sm89(XQ, WQ, x_scale, w_scale, bias, out);