mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Skip softmax BF16 test for ROCm (#21162)
### Description Skip softmax BF16 test for ROCm, because BFloat16 is unsupported by MIOpen, and `torch.cuda.is_available()` also returns `True` for ROCm.
This commit is contained in:
parent
41ad83fb00
commit
e2abba18ea
1 changed files with 3 additions and 3 deletions
|
|
@ -148,8 +148,8 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
|
||||
def test_softmax_bf16_large(self):
|
||||
if not torch.cuda.is_available():
|
||||
# only test bf16 on cuda
|
||||
if torch.version.cuda is None:
|
||||
# Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen.
|
||||
return
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
|
@ -175,7 +175,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
|
|||
data_ort.requires_grad = True
|
||||
ort_res = ort_model(input=data_ort)
|
||||
ort_res.backward(gradient=init_grad)
|
||||
# compara result
|
||||
# compare result
|
||||
torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue