From bdc2c2a23752e8cb49eaa07c3ce66d6af583b1b5 Mon Sep 17 00:00:00 2001
From: Yidi Wu
Date: Tue, 21 Jan 2025 16:55:39 -0800
Subject: [PATCH] [be] fix flaky test aot_export_cond caused by free symbol
 lifting and automatic dynamic shape (#145330)

Fixes https://github.com/pytorch/pytorch/issues/139998#issuecomment-2605908426.

This appears to be caused by the interaction between Dynamo-traced higher-order ops, automatic dynamic shapes, and automatic lifting of free symbols.

The immediate symptom is that the assertExpectedInline graph can differ between runs, e.g. see https://hud.pytorch.org/flakytest?name=test_aot_export_with_torch_cond&suite=TestAOTExport&limit=100, where sometimes the shapes are lifted as inputs to the cond and sometimes they are not.

The root cause of the flakiness is that the two invocations of torch.cond trigger two torch.compile runs on the same code object ([code](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/cond.py#L192)). This triggers automatic dynamic shapes because in test_aot_export_with_torch_cond, x has shape (3, 4) while the pre_dispatch input has shape (2, 2). Because we automatically lift free symbols for dynamic-shaped inputs, cond sometimes has the shapes as arguments and sometimes does not.

This PR adds a simple fix: call torch._dynamo.reset() before each torch.cond test. This fixes the error by not triggering automatic dynamic shapes.
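The recompilation behavior described above can be sketched with a simplified, torch-free model of the automatic-dynamic heuristic (illustrative only; `compile_policy` and the shape cache are hypothetical stand-ins for the real logic in torch/_dynamo). The first compile of a code object specializes on the exact input shape; a second compile of the same code object with a different shape marks the differing dimensions dynamic, which is what causes the shape symbols to be lifted as cond inputs:

```python
# Simplified model of Dynamo's "automatic dynamic shapes" heuristic.
# compile_cache maps a code object to the input shape seen at first compile.
compile_cache = {}

def compile_policy(code_object, shape):
    """Return 'static' on first compile, 'dynamic' if the shape changed."""
    seen = compile_cache.get(code_object)
    if seen is None:
        compile_cache[code_object] = shape
        return "static"
    if seen != shape:
        # Differing dims become symbolic, i.e. shapes get lifted as inputs.
        return "dynamic"
    return "static"

# Two torch.cond calls hit the same code object with different shapes:
first = compile_policy("cond_body", (3, 4))   # plain test input
second = compile_policy("cond_body", (2, 2))  # pre_dispatch input

# Clearing the cache (analogous to torch._dynamo.reset()) makes each
# test start from a fresh "static" specialization again:
compile_cache.clear()
after_reset = compile_policy("cond_body", (2, 2))
```

Under this model, `second` comes back "dynamic" while `after_reset` is "static", which is why resetting Dynamo state in setUp makes the expected graph deterministic.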
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145330
Approved by: https://github.com/zou3519
---
 test/functorch/test_aotdispatch.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 9640b648591..a9b736fc811 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -3960,6 +3960,10 @@ class TestMod(torch.nn.Module):
 
 
 class TestAOTExport(AOTTestCase):
+    def setUp(self):
+        super().setUp()
+        torch._dynamo.reset()
+
     def test_aot_export_ban_dropout_mut_pre_dispatch(self):
         def fn(p, x):
             y = torch.ops.aten.dropout.default(x, 0.1, train=False)