diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index a0c68238de3..211a45028fa 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -301,7 +301,10 @@ class FakeDDP(nn.Module):
             DDP._active_ddp_module = None
 
     def forward(self, *inputs, **kwargs):
-        with self._inside_ddp_forward():
+        if not DDP._active_ddp_module:
+            with self._inside_ddp_forward():
+                return self.module.forward(*inputs, **kwargs)
+        else:
             return self.module.forward(*inputs, **kwargs)
@@ -372,6 +375,43 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512))
 
+    @patch.object(config, "optimize_ddp", True)
+    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc_0 = nn.Linear(768, 768)
+                self.fc_1 = nn.Linear(768, 768)
+
+            def forward(self, x):
+                x = self.fc_0(x)
+                x = self.fc_1(x)
+                return x
+
+        model = Model()
+        model = FakeDDP(model)
+
+        inp = torch.randn((16, 18, 768))
+        inp2 = torch.randn((16, 20, 768))
+
+        torch._dynamo.mark_dynamic(inp, 1)
+        torch._dynamo.mark_dynamic(inp2, 1)
+
+        torch._dynamo.utils.clear_compilation_metrics()
+        torch._dynamo.reset()
+        try:
+            DDP._active_ddp_module = model
+            opt_model = torch.compile(model)
+            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
+            opt_model(inp)
+            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
+            opt_model(inp2)
+            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
+            # no recompiles
+            self.assertEqual(compile_count_before, compile_count_after)
+        finally:
+            DDP._active_ddp_module = None
+
     @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
     def test_unbacked_symbol_splitting_direct(self):
         class Model(nn.Module):
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 3f91ee87b3a..4ed3d8b75c4 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2150,16 +2150,18 @@ def set_tracing_context_output_strides(example_inputs, compiled_graph):
             if exprs is None:
                 context.output_strides.append(None)
             else:
-                context.output_strides.append(
-                    tuple(
-                        (
-                            shape_env.evaluate_symexpr(e)
-                            if shape_env is not None
-                            else int(e)
-                        )
-                        for e in exprs
-                    )
-                )
+                fakify_first_call = False
+                if ctx := torch._guards.TracingContext.try_get():
+                    fakify_first_call = ctx.fakify_first_call
+
+                def map_expr(e):
+                    if shape_env is None:
+                        return int(e)
+                    if fakify_first_call:
+                        return shape_env.deserialize_symexpr(e)
+                    return shape_env.evaluate_symexpr(e)
+
+                context.output_strides.append(tuple(map_expr(e) for e in exprs))
 
 
 def should_use_remote_fx_graph_cache():
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 7ae10d0ff89..a68cd95260f 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -5231,6 +5231,16 @@ class ShapeEnv:
         args = {str(e): val for e, val in self.var_to_val.items()}
         return eval(code, SYMPY_INTERP, args)
 
+    def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]:
+        """
+        To be used by compile_fx to deserialize symexprs
+        """
+        args = {
+            str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None))
+            for e, val in self.var_to_val.items()
+        }
+        return eval(code, SYMPY_INTERP, args)
+
     def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool:
         """
         Expected to be used with produce_guards_expression(). Evaluates an expression
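
Reviewer note: a minimal, self-contained sketch of the distinction the new `deserialize_symexpr` path preserves. This is an analogue, not the real `ShapeEnv` API: it binds a plain `sympy.Symbol` where the actual method wraps a `SymNode` in a `SymInt`, and `serialized` is a made-up stride expression. The point is that both paths `eval` the same serialized string (the real code uses `SYMPY_INTERP` as the globals); only the bindings for the free symbols differ.

```python
# Hypothetical analogue of evaluate_symexpr vs. deserialize_symexpr.
# `serialized` stands in for one serialized output-stride expression.
import sympy

serialized = "2*s0 + 1"

# evaluate_symexpr-style: bind free symbols to their integer hints
# (var_to_val), collapsing the expression to a plain int. This is what
# specialized output strides to the first-seen dynamic size.
hint_bindings = {"s0": 18}
print(eval(serialized, {}, hint_bindings))  # 37, a concrete int

# deserialize_symexpr-style: bind free symbols to symbolic objects (here a
# sympy.Symbol; the real method builds SymInt(SymNode(...))), so the result
# stays symbolic and a later call with a different s0 need not recompile.
sym_bindings = {"s0": sympy.Symbol("s0")}
print(eval(serialized, {}, sym_bindings))  # 2*s0 + 1, still symbolic
```

The `fakify_first_call` check in `set_tracing_context_output_strides` selects between these two behaviors, which is why the new test can run the compiled DDP module on a second dynamic-size input without triggering a recompile.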