mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] Fix numpy test accuracy error induced by randomness divergence (#145293)
Previously `TestGradient.test_second_order_accurate` was failing because of a small tolerance error (0.03... which is above the 0.03 tolerance). Upon investigating, `np.random.random` caused some divergence between eager and compiled randomness because in compiled we are not using `np.random`'s random seed, rather we end up using `torch`'s. This in turn caused numerical divergence and aforementioned accuracy issue. This patch fixes the failure by patching the test case with `use_numpy_random_stream=True`, which forces a graph break on `np.random.random()` and thereby falling back to eager to ensure consistency of the numpy randomness. Fixes #116746. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145293 Approved by: https://github.com/lezcano
This commit is contained in:
parent
2efa98d69d
commit
9f150786bb
2 changed files with 2 additions and 0 deletions
|
|
@ -16,6 +16,7 @@ import pytest
|
|||
from hypothesis.extra.numpy import arrays
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
|
|
@ -1013,6 +1014,7 @@ class TestGradient(TestCase):
|
|||
assert_raises(TypeError, gradient, f_2d, x, x, axis=1)
|
||||
assert_raises(TypeError, gradient, f_2d, 1, 1, axis=1)
|
||||
|
||||
@torch._dynamo.config.patch(use_numpy_random_stream=True)
|
||||
def test_second_order_accurate(self):
|
||||
# Testing that the relative numerical error is less that 3% for
|
||||
# this example problem. This corresponds to second order
|
||||
|
|
|
|||
Loading…
Reference in a new issue