From fbf5d09a0cecf40522ccd8b74b2f08564acb7159 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 20 Feb 2023 18:09:53 +0800 Subject: [PATCH] Fix random failure of ortmodule_api.py::test_unused_parameters (#14729) ### Fix random failure of ortmodule_api.py::test_unused_parameters Fix FAILED orttraining_test_ortmodule_api.py::test_unused_parameters[model1-none_pt_params1] for orttraining-linux-gpu-ci-pipeline CI pipeline ``` =================================== FAILURES =================================== ________________ test_unused_parameters[model1-none_pt_params1] ________________ model = UnusedMiddleParameterNet( (fc1): Linear(in_features=784, out_features=500, bias=True) (relu): ReLU() (fc2): Linear(in_features=500, out_features=400, bias=True) (fc3): Linear(in_features=500, out_features=10, bias=True) ) none_pt_params = ['fc2.weight', 'fc2.bias'] @pytest.mark.parametrize( "model, none_pt_params", [ (UnusedBeginParameterNet(784, 500, 400, 10), ["fc1.weight", "fc1.bias"]), (UnusedMiddleParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]), (UnusedEndParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]), ], ) def test_unused_parameters(model, none_pt_params): device = "cuda" N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 model = model.to(device) ort_model = ORTModule(copy.deepcopy(model)) # Make sure model runs without any exception for _ in range(5): x = torch.randn(N, D_in, device=device) y = copy.deepcopy(x) out_pt = model(x) out_ort = ort_model(y) loss_pt = out_pt.sum() loss_pt.backward() loss_ort = out_ort.sum() loss_ort.backward() _test_helpers.assert_values_are_close(out_ort, out_pt) > _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, none_pt_params=none_pt_params) orttraining_test_ortmodule_api.py:4050: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _test_helpers.py:216: in assert_gradients_match_and_reset_gradient assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ ``` Initially the test runs very well. As we insert more and more tests, when running ortmodule_api.py::test_unused_parameters, the random generated data got changed, and now it is more easily to generate an input data that produce a result the break existing rtol and atol. The example data, 0.1041 only have very minor diff, e.g. abs_diff: 2.2649765014648438e-06. > The torch.allclose judge it is not equal because: abs_diff> 0.1041 * rtol + atol = 1.041e-1 * 1e-5 + 1e-6 =-2.041e-6. > Additionally, according to math [here](https://github.com/microsoft/onnxruntime/blob/7b31bcda2e9cd45c709e3fe31a544297db37ea3c/orttraining/orttraining/test/python/_test_helpers.py#L230) The maximum atol is 1.2238311910550692e-06 > current atol(1e-6), maximum rtol is 1.2149855137977283e-05 > current rtol(1e-5). This PR looses the atol to 1e-5, rtol to 1e-4 . --- orttraining/orttraining/test/python/_test_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index 95c3b58521..73f2940467 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -198,7 +198,7 @@ def _get_name(name): # Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated. # none_pt_params is to tell what these weights are, so we will not compare the tensors. def assert_gradients_match_and_reset_gradient( - ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06 + ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-04, atol=1e-05 ): ort_named_params = list(ort_model.named_parameters()) pt_named_params = list(pt_model.named_parameters()) @@ -220,7 +220,7 @@ def assert_gradients_match_and_reset_gradient( pt_param.grad = None -def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): +def assert_values_are_close(input, other, rtol=1e-04, atol=1e-05): are_close = torch.allclose(input, other, rtol=rtol, atol=atol) if not are_close: abs_diff = torch.abs(input - other)