diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 84aeb22b48..d611253c80 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -599,9 +599,9 @@ def test_mixed_nnmodule_ortmodules_training(): pt_p1, pt_p2, pt_p3 = run_step(pt_model1, pt_model2, pt_model3, x1, x2) ort_p1, ort_p2, ort_p3 = run_step(ort_model1, ort_model2, ort_model3, x1, x2) - assert torch.allclose(ort_p1, pt_p1) - assert torch.allclose(ort_p2, pt_p2) - # assert torch.allclose(ort_p3, pt_p3) # TODO: this assert is failing, need to investigate!! + assert torch.allclose(ort_p1, pt_p1, atol=1e-06) + assert torch.allclose(ort_p2, pt_p2, atol=1e-06) + assert torch.allclose(ort_p3, pt_p3, atol=1e-06) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3) @@ -686,7 +686,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device): pt_y1 = run_step0(pt_model, pt_x1, pt_x2) ort_y1 = run_step0(ort_model, ort_x1, ort_x2) - #assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!! + assert torch.allclose(pt_y1, ort_y1, atol=1e-06) assert torch.allclose(ort_x1.grad, pt_x1.grad) assert torch.allclose(ort_x2.grad, pt_x2.grad) # backward() is from y1, so grad of fc2.weight and fc2.bias will not be calculated. @@ -701,7 +701,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device): pt_y2 = run_step1(pt_model, pt_x1, pt_x2) ort_y2 = run_step1(ort_model, ort_x1, ort_x2) - #assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!! + assert torch.allclose(pt_y2, ort_y2, atol=1e-06) assert torch.allclose(ort_x1.grad, pt_x1.grad) assert torch.allclose(ort_x2.grad, pt_x2.grad) # backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated. @@ -729,8 +729,8 @@ def test_loss_combines_two_outputs_with_dependency(device): pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2) ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2) - #assert torch.allclose(pt_y1, ort_y1) # TODO: this assert is failing, need to investigate!! - #assert torch.allclose(pt_y2, ort_y2) # TODO: this assert is failing, need to investigate!! + assert torch.allclose(pt_y1, ort_y1, atol=1e-06) + assert torch.allclose(pt_y2, ort_y2, atol=1e-06) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) @pytest.mark.parametrize("x1_requires_grad, x2_requires_grad", [(True, True), (True, False), (False, False), (False, True)]) @@ -757,8 +757,8 @@ def test_input_requires_grad_backward_creates_input_grad_as_required1(x1_require pt_y1, pt_y2 = run_step(pt_model, pt_x1, pt_x2) ort_y1, ort_y2 = run_step(ort_model, ort_x1, ort_x2) - # assert torch.allclose(ort_y1, pt_y1) # TODO: this assert is failing, need to investigate!! - # assert torch.allclose(ort_y2, pt_y2) # TODO: this assert is failing, need to investigate!! + assert torch.allclose(ort_y1, pt_y1, atol=1e-06) + assert torch.allclose(ort_y2, pt_y2, atol=1e-06) assert not x1_requires_grad or ort_x1.grad is not None assert not x2_requires_grad or ort_x2.grad is not None assert not x1_requires_grad or torch.allclose(ort_x1.grad, pt_x1.grad)