Relax atol for some ORTModule UTs (#6969)

This commit is contained in:
Vincent Wang 2021-03-11 00:59:56 +08:00 committed by GitHub
parent 534adbb065
commit 3f579facbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -599,9 +599,9 @@ def test_mixed_nnmodule_ortmodules_training():
pt_p1, pt_p2, pt_p3 = run_step(pt_model1, pt_model2, pt_model3, x1, x2)
ort_p1, ort_p2, ort_p3 = run_step(ort_model1, ort_model2, ort_model3, x1, x2)
assert torch.allclose(ort_p1, pt_p1)
assert torch.allclose(ort_p2, pt_p2)
# assert torch.allclose(ort_p3, pt_p3) # TODO: this assert is failing, need to investigate!!
assert torch.allclose(ort_p1, pt_p1, atol=1e-06)
assert torch.allclose(ort_p2, pt_p2, atol=1e-06)
assert torch.allclose(ort_p3, pt_p3, atol=1e-06)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3)
@ -686,7 +686,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
pt_y1 = run_step0(pt_model, pt_x1, pt_x2)
ort_y1 = run_step0(ort_model, ort_x1, ort_x2)
#assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!!
assert torch.allclose(pt_y1, ort_y1, atol=1e-06)
assert torch.allclose(ort_x1.grad, pt_x1.grad)
assert torch.allclose(ort_x2.grad, pt_x2.grad)
# backward() is from y1, so grad of fc2.weight and fc2.bias will not be calculated.
@ -701,7 +701,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device):
pt_y2 = run_step1(pt_model, pt_x1, pt_x2)
ort_y2 = run_step1(ort_model, ort_x1, ort_x2)
#assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!!
assert torch.allclose(pt_y2, ort_y2, atol=1e-06)
assert torch.allclose(ort_x1.grad, pt_x1.grad)
assert torch.allclose(ort_x2.grad, pt_x2.grad)
# backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated.
@ -729,8 +729,8 @@ def test_loss_combines_two_outputs_with_dependency(device):
pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2)
ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2)
#assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!!
#assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!!
assert torch.allclose(pt_y1, ort_y1, atol=1e-06)
assert torch.allclose(pt_y2, ort_y2, atol=1e-06)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("x1_requires_grad, x2_requires_grad", [(True, True), (True, False), (False, False), (False, True)])
@ -757,8 +757,8 @@ def test_input_requires_grad_backward_creates_input_grad_as_required1(x1_require
pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2)
ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2)
# assert torch.allclose(ort_y1, pt_y1) # TODO: this assert is failing, need to investigate!!
# assert torch.allclose(ort_y2, pt_y2) # TODO: this assert is failing, need to investigate!!
assert torch.allclose(ort_y1, pt_y1, atol=1e-06)
assert torch.allclose(ort_y2, pt_y2, atol=1e-06)
assert not x1_requires_grad or ort_x1.grad is not None
assert not x2_requires_grad or ort_x2.grad is not None
assert not x1_requires_grad or torch.allclose(ort_x1.grad, pt_x1.grad)