diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 6037267ae9..48b971ea16 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -890,7 +890,7 @@ def test_multiple_ortmodules_common_backbone_training(): ort_prediction = run_step(ort_model0, ort_model2, x1) _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True, atol=1.5e-6) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True, atol=1e-5) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) def test_multiple_chained_ortmodules_training():