mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Relax atol for some ORTModule UTs (#6969)
This commit is contained in:
parent
534adbb065
commit
3f579facbc
1 changed files with 9 additions and 9 deletions
|
|
@ -599,9 +599,9 @@ def test_mixed_nnmodule_ortmodules_training():
|
|||
pt_p1, pt_p2, pt_p3 = run_step(pt_model1, pt_model2, pt_model3, x1, x2)
|
||||
ort_p1, ort_p2, ort_p3 = run_step(ort_model1, ort_model2, ort_model3, x1, x2)
|
||||
|
||||
assert torch.allclose(ort_p1, pt_p1)
|
||||
assert torch.allclose(ort_p2, pt_p2)
|
||||
# assert torch.allclose(ort_p3, pt_p3) # TODO: this assert is failing, need to investigate!!
|
||||
assert torch.allclose(ort_p1, pt_p1, atol=1e-06)
|
||||
assert torch.allclose(ort_p2, pt_p2, atol=1e-06)
|
||||
assert torch.allclose(ort_p3, pt_p3, atol=1e-06)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
|
||||
|
|
@ -686,7 +686,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
|
|||
pt_y1 = run_step0(pt_model, pt_x1, pt_x2)
|
||||
ort_y1 = run_step0(ort_model, ort_x1, ort_x2)
|
||||
|
||||
#assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!!
|
||||
assert torch.allclose(pt_y1, ort_y1, atol=1e-06)
|
||||
assert torch.allclose(ort_x1.grad, pt_x1.grad)
|
||||
assert torch.allclose(ort_x2.grad, pt_x2.grad)
|
||||
# backward() is from y1, so grad of fc2.weight and fc2.bias will not be calculated.
|
||||
|
|
@ -701,7 +701,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
|
|||
pt_y2 = run_step1(pt_model, pt_x1, pt_x2)
|
||||
ort_y2 = run_step1(ort_model, ort_x1, ort_x2)
|
||||
|
||||
#assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!!
|
||||
assert torch.allclose(pt_y2, ort_y2, atol=1e-06)
|
||||
assert torch.allclose(ort_x1.grad, pt_x1.grad)
|
||||
assert torch.allclose(ort_x2.grad, pt_x2.grad)
|
||||
# backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated.
|
||||
|
|
@ -729,8 +729,8 @@ def test_loss_combines_two_outputs_with_dependency(device):
|
|||
pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2)
|
||||
ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2)
|
||||
|
||||
#assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!!
|
||||
#assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!!
|
||||
assert torch.allclose(pt_y1, ort_y1, atol=1e-06)
|
||||
assert torch.allclose(pt_y2, ort_y2, atol=1e-06)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("x1_requires_grad, x2_requires_grad", [(True, True), (True, False), (False, False), (False, True)])
|
||||
|
|
@ -757,8 +757,8 @@ def test_input_requires_grad_backward_creates_input_grad_as_required1(x1_require
|
|||
pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2)
|
||||
ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2)
|
||||
|
||||
# assert torch.allclose(ort_y1, pt_y1) # TODO: this assert is failing, need to investigate!!
|
||||
# assert torch.allclose(ort_y2, pt_y2) # TODO: this assert is failing, need to investigate!!
|
||||
assert torch.allclose(ort_y1, pt_y1, atol=1e-06)
|
||||
assert torch.allclose(ort_y2, pt_y2, atol=1e-06)
|
||||
assert not x1_requires_grad or ort_x1.grad is not None
|
||||
assert not x2_requires_grad or ort_x2.grad is not None
|
||||
assert not x1_requires_grad or torch.allclose(ort_x1.grad, pt_x1.grad)
|
||||
|
|
|
|||
Loading…
Reference in a new issue