avoid specializing strides with DDPOptimizer + inductor (#140751)

Fixes https://github.com/pytorch/pytorch/issues/140229

Fixes https://github.com/pytorch/pytorch/issues/139474

The issue was that:

(1) DDPOptimizer has some logic to partition the dynamo graph into buckets, and run AOTAutograd/inductor on each bucket

(2) doing so requires knowing the **exact** strides of the outputs of each subgraph, so that we can construct example inputs (with the correct strides) for compiling each of the later subgraphs

(3) there is some existing logic to do this today: we have a `fakify_first_call` flag in AOTAutograd that lets you run it with fake tensor inputs (to handle the calling convention changes that AOTAutograd performs at runtime). During this process, we query inductor for the output strides that it compiled with

(4) these output strides are stored in the FX graph cache as raw strings of sympy expressions. We have a function, `evaluate_symexpr`, which, given the sympy string and the ShapeEnv's `var_to_val` mapping, evaluates the string to produce concrete strides

(5) evaluating the expression that way, however, specializes on the exact values of any variables in our shape env. In DDPOptimizer, we want to know inductor's output strides symbolically, which requires converting the (string) sympy expression into actual `SymInt`s that we can return. The sketch below illustrates the difference.
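
To make the distinction in (5) concrete, here is a minimal sketch in plain sympy (not the actual `ShapeEnv` API; the symbol `s1` and the hint value 18 are invented for illustration) contrasting evaluation against concrete hints with keeping the expression symbolic:

```python
import sympy

s1 = sympy.Symbol("s1")                # a dynamic dim tracked by the shape env
stride_expr = sympy.sympify("768*s1")  # a stride string as cached for a subgraph
var_to_val = {s1: 18}                  # the concrete hint recorded at trace time

# Specializing: substituting the hint yields a plain int. This bakes in
# s1 == 18, so a later call with s1 == 20 no longer matches and recompiles.
specialized = stride_expr.subs(var_to_val)

# What DDPOptimizer needs instead: leave the stride in terms of the symbol,
# so the later subgraphs compile against a symbolic stride.
print(specialized, stride_expr)  # 13824 768*s1
```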

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140751
Approved by: https://github.com/eellison
Brian Hirsh, 2024-12-04 14:53:37 -08:00, committed by PyTorch MergeBot
commit 471017cbc9 (parent b08bc07cd7)
3 changed files with 63 additions and 11 deletions

test/distributed/test_dynamo_distributed.py

```diff
@@ -301,7 +301,10 @@ class FakeDDP(nn.Module):
             DDP._active_ddp_module = None
 
     def forward(self, *inputs, **kwargs):
-        with self._inside_ddp_forward():
+        if not DDP._active_ddp_module:
+            with self._inside_ddp_forward():
+                return self.module.forward(*inputs, **kwargs)
+        else:
             return self.module.forward(*inputs, **kwargs)
@@ -372,6 +375,43 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512))
 
+    @patch.object(config, "optimize_ddp", True)
+    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc_0 = nn.Linear(768, 768)
+                self.fc_1 = nn.Linear(768, 768)
+
+            def forward(self, x):
+                x = self.fc_0(x)
+                x = self.fc_1(x)
+                return x
+
+        model = Model()
+        model = FakeDDP(model)
+        inp = torch.randn((16, 18, 768))
+        inp2 = torch.randn((16, 20, 768))
+        torch._dynamo.mark_dynamic(inp, 1)
+        torch._dynamo.mark_dynamic(inp2, 1)
+        torch._dynamo.utils.clear_compilation_metrics()
+        torch._dynamo.reset()
+        try:
+            DDP._active_ddp_module = model
+            opt_model = torch.compile(model)
+            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
+            opt_model(inp)
+            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
+            opt_model(inp2)
+            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
+            # no recompiles
+            self.assertEqual(compile_count_before, compile_count_after)
+        finally:
+            DDP._active_ddp_module = None
+
     @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
     def test_unbacked_symbol_splitting_direct(self):
         class Model(nn.Module):
```
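
Outside the FakeDDP harness, the recompile check in the test above can be reproduced standalone; the following sketch uses the same helpers the test relies on (`mark_dynamic` plus the dynamo compilation metrics), with a made-up function `f` and made-up sizes:

```python
import torch

def f(x):
    return x * 2

inp = torch.randn(16, 18)
inp2 = torch.randn(16, 20)
torch._dynamo.mark_dynamic(inp, 1)   # dim 1 is dynamic
torch._dynamo.mark_dynamic(inp2, 1)

torch._dynamo.utils.clear_compilation_metrics()
opt_f = torch.compile(f)
opt_f(inp)
before = len(torch._dynamo.utils.get_compilation_metrics())
opt_f(inp2)
after = len(torch._dynamo.utils.get_compilation_metrics())
assert before == after  # the dynamic dim should not force a recompile
```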

torch/_inductor/compile_fx.py

```diff
@@ -2150,16 +2150,18 @@ def set_tracing_context_output_strides(example_inputs, compiled_graph):
             if exprs is None:
                 context.output_strides.append(None)
             else:
-                context.output_strides.append(
-                    tuple(
-                        (
-                            shape_env.evaluate_symexpr(e)
-                            if shape_env is not None
-                            else int(e)
-                        )
-                        for e in exprs
-                    )
-                )
+                fakify_first_call = False
+                if ctx := torch._guards.TracingContext.try_get():
+                    fakify_first_call = ctx.fakify_first_call
+
+                def map_expr(e):
+                    if shape_env is None:
+                        return int(e)
+                    if fakify_first_call:
+                        return shape_env.deserialize_symexpr(e)
+                    return shape_env.evaluate_symexpr(e)
+
+                context.output_strides.append(tuple(map_expr(e) for e in exprs))
 
 
 def should_use_remote_fx_graph_cache():
```
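
Both `shape_env` branches of `map_expr` ultimately `eval` the cached stride string against a namespace of per-symbol bindings (see the `ShapeEnv` diff below), so the binding type decides whether the result is a concrete int or stays symbolic. A toy stand-in makes the mechanics visible; `FakeSymInt` is invented purely for illustration and is not PyTorch's `SymInt`:

```python
# "768*s1" mimics a cached stride string; the bindings decide the result type.
class FakeSymInt:
    def __init__(self, expr: str):
        self.expr = expr

    def __rmul__(self, other):
        # 768 * FakeSymInt("s1") falls back to __rmul__, staying symbolic
        return FakeSymInt(f"{other}*{self.expr}")

    def __repr__(self):
        return self.expr

code = "768*s1"
concrete_args = {"s1": 18}                # evaluate_symexpr-style bindings
symbolic_args = {"s1": FakeSymInt("s1")}  # deserialize_symexpr-style bindings

print(eval(code, {}, concrete_args))  # 13824 -> specialized to the hint
print(eval(code, {}, symbolic_args))  # 768*s1 -> still symbolic
```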

torch/fx/experimental/symbolic_shapes.py

```diff
@@ -5231,6 +5231,16 @@ class ShapeEnv:
         args = {str(e): val for e, val in self.var_to_val.items()}
         return eval(code, SYMPY_INTERP, args)
 
+    def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]:
+        """
+        To be used by compile_fx to deserialize symexprs
+        """
+        args = {
+            str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None))
+            for e, val in self.var_to_val.items()
+        }
+        return eval(code, SYMPY_INTERP, args)
+
     def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool:
         """
         Expected to be used with produce_guards_expression(). Evaluates an expression
```