mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ROCm] Enable some sparse tests on ROCm (#77877)
Enabling: test_sampled_addmm_errors_cuda_complex128 test_sampled_addmm_errors_cuda_complex64 test_sampled_addmm_errors_cuda_float32 test_sampled_addmm_errors_cuda_float64 test_sparse_add_cuda_complex128 test_sparse_add_cuda_complex64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77877 Approved by: https://github.com/pruthvistony, https://github.com/malfet
This commit is contained in:
parent
20d56d2b32
commit
aff7eef476
1 changed files with 1 additions and 5 deletions
|
|
@ -1538,9 +1538,6 @@ class TestSparseCSR(TestCase):
|
|||
def test_sparse_add(self, device, dtype):
|
||||
def run_test(m, n, index_dtype):
|
||||
|
||||
if TEST_WITH_ROCM and dtype.is_complex:
|
||||
self.skipTest("ROCm doesn't work with complex dtype correctly.")
|
||||
|
||||
alpha = random.random()
|
||||
nnz1 = random.randint(0, m * n)
|
||||
nnz2 = random.randint(0, m * n)
|
||||
|
|
@ -1744,10 +1741,9 @@ class TestSparseCSR(TestCase):
|
|||
b = make_tensor((k, n), dtype=dtype, device=device)
|
||||
run_test(c, a, b)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@onlyCUDA
|
||||
@skipCUDAIf(
|
||||
not _check_cusparse_sddmm_available(),
|
||||
not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()),
|
||||
"cuSparse Generic API SDDMM is not available"
|
||||
)
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
|
|
|
|||
Loading…
Reference in a new issue