diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index 95c3b58521..73f2940467 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -198,7 +198,7 @@ def _get_name(name): # Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated. # none_pt_params is to tell what these weights are, so we will not compare the tensors. def assert_gradients_match_and_reset_gradient( - ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06 + ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-04, atol=1e-05 ): ort_named_params = list(ort_model.named_parameters()) pt_named_params = list(pt_model.named_parameters()) @@ -220,7 +220,7 @@ def assert_gradients_match_and_reset_gradient( pt_param.grad = None -def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): +def assert_values_are_close(input, other, rtol=1e-04, atol=1e-05): are_close = torch.allclose(input, other, rtol=rtol, atol=atol) if not are_close: abs_diff = torch.abs(input - other)