Fix random failure of ortmodule_api.py::test_unused_parameters (#14729)

### Fix random failure of ortmodule_api.py::test_unused_parameters

Fix FAILED
orttraining_test_ortmodule_api.py::test_unused_parameters[model1-none_pt_params1]
for orttraining-linux-gpu-ci-pipeline CI pipeline

```
=================================== FAILURES ===================================
________________ test_unused_parameters[model1-none_pt_params1] ________________

model = UnusedMiddleParameterNet(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=500, out_features=400, bias=True)
  (fc3): Linear(in_features=500, out_features=10, bias=True)
)
none_pt_params = ['fc2.weight', 'fc2.bias']

    @pytest.mark.parametrize(
        "model, none_pt_params",
        [
            (UnusedBeginParameterNet(784, 500, 400, 10), ["fc1.weight", "fc1.bias"]),
            (UnusedMiddleParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]),
            (UnusedEndParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]),
        ],
    )
    def test_unused_parameters(model, none_pt_params):
        device = "cuda"
    
        N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10
        model = model.to(device)
        ort_model = ORTModule(copy.deepcopy(model))
    
        # Make sure model runs without any exception
        for _ in range(5):
            x = torch.randn(N, D_in, device=device)
            y = copy.deepcopy(x)
    
            out_pt = model(x)
            out_ort = ort_model(y)
            loss_pt = out_pt.sum()
            loss_pt.backward()
            loss_ort = out_ort.sum()
            loss_ort.backward()
            _test_helpers.assert_values_are_close(out_ort, out_pt)
>           _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, none_pt_params=none_pt_params)

orttraining_test_ortmodule_api.py:4050: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_test_helpers.py:216: in assert_gradients_match_and_reset_gradient
    assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

```

Initially the test runs very well. As we insert more and more tests,
when running ortmodule_api.py::test_unused_parameters, the random
generated data got changed, and now it is more easily to generate an
input data that produce a result the break existing rtol and atol.

The example data, 0.1041 only have very minor diff, e.g. abs_diff:
2.2649765014648438e-06.
> The torch.allclose judge it is not equal because: abs_diff> 0.1041 *
rtol + atol = 1.041e-1 * 1e-5 + 1e-6 =-2.041e-6.
> Additionally, according to math
[here](7b31bcda2e/orttraining/orttraining/test/python/_test_helpers.py (L230))
The maximum atol is 1.2238311910550692e-06 > current atol(1e-6), maximum
rtol is 1.2149855137977283e-05 > current rtol(1e-5).

This PR looses the atol to 1e-5, rtol to 1e-4 .
This commit is contained in:
pengwa 2023-02-20 18:09:53 +08:00 committed by GitHub
parent ad78579b66
commit fbf5d09a0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -198,7 +198,7 @@ def _get_name(name):
# Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated.
# none_pt_params is to tell what these weights are, so we will not compare the tensors.
def assert_gradients_match_and_reset_gradient(
ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06
ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-04, atol=1e-05
):
ort_named_params = list(ort_model.named_parameters())
pt_named_params = list(pt_model.named_parameters())
@ -220,7 +220,7 @@ def assert_gradients_match_and_reset_gradient(
pt_param.grad = None
def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06):
def assert_values_are_close(input, other, rtol=1e-04, atol=1e-05):
are_close = torch.allclose(input, other, rtol=rtol, atol=atol)
if not are_close:
abs_diff = torch.abs(input - other)