[ROCm] Enable some sparse tests on ROCm (#77877)

Enabling:
test_sampled_addmm_errors_cuda_complex128
test_sampled_addmm_errors_cuda_complex64
test_sampled_addmm_errors_cuda_float32
test_sampled_addmm_errors_cuda_float64
test_sparse_add_cuda_complex128
test_sparse_add_cuda_complex64

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77877
Approved by: https://github.com/pruthvistony, https://github.com/malfet
This commit is contained in:
jpvillam 2022-06-14 21:11:33 +00:00 committed by PyTorch MergeBot
parent 20d56d2b32
commit aff7eef476

View file

@ -1538,9 +1538,6 @@ class TestSparseCSR(TestCase):
def test_sparse_add(self, device, dtype):
def run_test(m, n, index_dtype):
if TEST_WITH_ROCM and dtype.is_complex:
self.skipTest("ROCm doesn't work with complex dtype correctly.")
alpha = random.random()
nnz1 = random.randint(0, m * n)
nnz2 = random.randint(0, m * n)
@ -1744,10 +1741,9 @@ class TestSparseCSR(TestCase):
b = make_tensor((k, n), dtype=dtype, device=device)
run_test(c, a, b)
@skipCUDAIfRocm
@onlyCUDA
@skipCUDAIf(
not _check_cusparse_sddmm_available(),
not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()),
"cuSparse Generic API SDDMM is not available"
)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)