diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index d2494bd43d5..9bfc74b4ed2 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -708,13 +708,13 @@ void dispatch_fp8_rowwise_kernel_on_sm( at::Tensor out) { cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; - const bool sm90OrLater = properties != nullptr && properties->major >= 9; - if (!(sm89 || sm90OrLater)) { + const bool sm9x = properties != nullptr && properties->major == 9; + if (!(sm89 || sm9x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } - if (sm90OrLater) { + if (sm9x) { dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(XQ, WQ, x_scale, w_scale, bias, out); } else { f8f8bf16_rowwise_impl_sm89(XQ, WQ, x_scale, w_scale, bias, out);