Skip softmax BF16 test for ROCm (#21162)

### Description

Skip softmax BF16 test for ROCm, because BFloat16 is unsupported by
MIOpen, and `torch.cuda.is_available()` also returns `True` for ROCm.
This commit is contained in:
mindest 2024-06-26 11:15:50 +08:00 committed by GitHub
parent 41ad83fb00
commit e2abba18ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -148,8 +148,8 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
@unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
def test_softmax_bf16_large(self):
if not torch.cuda.is_available():
# only test bf16 on cuda
if torch.version.cuda is None:
# Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen.
return
class Model(torch.nn.Module):
@ -175,7 +175,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
data_ort.requires_grad = True
ort_res = ort_model(input=data_ort)
ort_res.backward(gradient=init_grad)
# compara result
# compare result
torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)