From aff7eef4764e5478c92e48be8d52f3e75e0ba705 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 14 Jun 2022 21:11:33 +0000 Subject: [PATCH] [ROCm] Enable some sparse tests on ROCm (#77877) Enabling: test_sampled_addmm_errors_cuda_complex128 test_sampled_addmm_errors_cuda_complex64 test_sampled_addmm_errors_cuda_float32 test_sampled_addmm_errors_cuda_float64 test_sparse_add_cuda_complex128 test_sparse_add_cuda_complex64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77877 Approved by: https://github.com/pruthvistony, https://github.com/malfet --- test/test_sparse_csr.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index ff8b432f3cb..367d7df47cb 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1538,9 +1538,6 @@ class TestSparseCSR(TestCase): def test_sparse_add(self, device, dtype): def run_test(m, n, index_dtype): - if TEST_WITH_ROCM and dtype.is_complex: - self.skipTest("ROCm doesn't work with complex dtype correctly.") - alpha = random.random() nnz1 = random.randint(0, m * n) nnz2 = random.randint(0, m * n) @@ -1744,10 +1741,9 @@ class TestSparseCSR(TestCase): b = make_tensor((k, n), dtype=dtype, device=device) run_test(c, a, b) - @skipCUDAIfRocm @onlyCUDA @skipCUDAIf( - not _check_cusparse_sddmm_available(), + not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()), "cuSparse Generic API SDDMM is not available" ) @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)