diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 9640b648591..a9b736fc811 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -3960,6 +3960,10 @@ class TestMod(torch.nn.Module): class TestAOTExport(AOTTestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + def test_aot_export_ban_dropout_mut_pre_dispatch(self): def fn(p, x): y = torch.ops.aten.dropout.default(x, 0.1, train=False)