diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 1ede6964be6..2367285f9d4 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1050,6 +1050,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels + // do not support this case). + if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) { + // `out` was created with `at::empty`. In the case where we are multiplying + // MxK by KxN and K is the zero dim, we need to initialize here to properly + // return a tensor of zeros. + if (mat1_sizes[1] == 0) { + out.zero_(); + } + + return out; + } + // We are doing row-wise scaling if (scaling_choice == ScalingType::RowWise) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 1d5f6bd711f..0af09695269 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -700,6 +700,33 @@ class TestFP8MatmulCuda(TestCase): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("which_dim_zero", [0, 1, 2]) + @parametrize("use_torch_compile", [False, True]) + def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: + device = "cuda" + x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn + out_dtype = torch.bfloat16 + M, K, N = 32, 32, 32 + if which_dim_zero == 0: + M = 0 + elif which_dim_zero == 1: + K = 0 + elif which_dim_zero == 2: + N = 0 + + x_fp8 = torch.zeros(M, K, device=device).to(x_dtype) + y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t() + out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) + scale_a = torch.tensor(float('-inf'), device=device) + scale_b = torch.tensor(float('-inf'), device=device) + f = torch._scaled_mm + if use_torch_compile: + f = torch.compile(torch._scaled_mm) + out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) + self.assertEqual(out_dtype, out_fp8.dtype) + self.assertEqual(out_fp32, out_fp8.to(torch.float)) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 291e3022288..74cfac42987 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5589,13 +5589,16 @@ def meta_scaled_mm( def is_col_major(stride): return stride[0] == 1 and stride[1] > 1 + def has_zero_dim(tensor_2d): + return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0 + torch._check( - is_row_major(self.stride()), - lambda: "self must be row_major", + is_row_major(self.stride()) or has_zero_dim(self), + lambda: f"self must be row_major, got stride {self.stride()}", ) torch._check( - is_col_major(mat2.stride()), - lambda: "mat2 must be col_major", + is_col_major(mat2.stride()) or has_zero_dim(mat2), + lambda: f"mat2 must be col_major, got stride {mat2.stride()}", ) torch._check( self.size(1) % 16 == 0,