avoid specializing strides with DDPOptimizer + inductor (#140751)
Fixes https://github.com/pytorch/pytorch/issues/140229
Fixes https://github.com/pytorch/pytorch/issues/139474

The issue was that:
(1) DDPOptimizer has logic to partition the dynamo graph into buckets and run AOTAutograd/inductor on each bucket.
(2) Doing so requires knowing the **exact** strides of the outputs of each subgraph, so that we can feed example inputs (with correct strides) to each of the later subgraphs when compiling them.
(3) Some logic for this exists today: AOTAutograd has a `fakify_first_call` flag that lets us run it with fake tensor inputs (to handle the calling-convention changes AOTAutograd performs at runtime). During this process, we query inductor for the output strides it compiled with.
(4) These output strides are stored in the FX graph cache as raw strings of sympy expressions. We have a function, `evaluate_symexpr`, which, given the sympy string and the ShapeEnv's `var_to_val` mapping, evaluates the string to produce concrete strides.
(5) Evaluating that expression specializes on the exact values of any variables in the shape env, however. In DDPOptimizer we want to know inductor's output strides symbolically, which requires converting the (string) sympy expression into actual `SymInt`s that we can return.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140751
Approved by: https://github.com/eellison
This commit is contained in:
parent b08bc07cd7
commit 471017cbc9
3 changed files with 63 additions and 11 deletions
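The crux of points (4) and (5) above is the difference between evaluating a stored stride expression concretely and keeping it symbolic. As a rough standalone sketch of that distinction (plain sympy, with a made-up expression string and symbol name, not the PyTorch code path):

import sympy

# Inductor records output strides as strings of sympy expressions, e.g. "768*s1".
stride_expr = "768*s1"

# Concrete evaluation: substitute the ShapeEnv's current hint (say s1 == 18).
# The result is a plain int, so anything compiled against it has silently
# specialized on s1 == 18.
print(int(sympy.sympify(stride_expr).subs({"s1": 18})))  # 13824

# Symbolic "deserialization": leave s1 free, so the stride still tracks the
# dynamic dimension when later subgraphs are compiled against it.
print(sympy.sympify(stride_expr))  # 768*s1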
@@ -301,7 +301,10 @@ class FakeDDP(nn.Module):
             DDP._active_ddp_module = None
 
     def forward(self, *inputs, **kwargs):
-        with self._inside_ddp_forward():
+        if not DDP._active_ddp_module:
+            with self._inside_ddp_forward():
+                return self.module.forward(*inputs, **kwargs)
+        else:
             return self.module.forward(*inputs, **kwargs)
 
 
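For orientation, here is a rough reconstruction of the helper being patched above, assuming `_inside_ddp_forward` is the usual set-then-clear context manager used by this test harness (the class name and bodies below are illustrative stand-ins, not the real test code). The new branch makes `forward` defer to an externally installed `DDP._active_ddp_module` (the new test below installs one by hand) instead of re-entering the bookkeeping context manager:

from contextlib import contextmanager

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class FakeDDPSketch(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    @contextmanager
    def _inside_ddp_forward(self):
        # Mark this wrapper as the active DDP module while forward runs.
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        if not DDP._active_ddp_module:
            # Normal path: set and clear the active-module flag around forward.
            with self._inside_ddp_forward():
                return self.module.forward(*inputs, **kwargs)
        else:
            # A caller (e.g. the test below) already marked a module active;
            # don't re-enter the context manager, which would clear that flag
            # when this call returns.
            return self.module.forward(*inputs, **kwargs)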
@@ -372,6 +375,43 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512))
 
+    @patch.object(config, "optimize_ddp", True)
+    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc_0 = nn.Linear(768, 768)
+                self.fc_1 = nn.Linear(768, 768)
+
+            def forward(self, x):
+                x = self.fc_0(x)
+                x = self.fc_1(x)
+                return x
+
+        model = Model()
+        model = FakeDDP(model)
+
+        inp = torch.randn((16, 18, 768))
+        inp2 = torch.randn((16, 20, 768))
+
+        torch._dynamo.mark_dynamic(inp, 1)
+        torch._dynamo.mark_dynamic(inp2, 1)
+
+        torch._dynamo.utils.clear_compilation_metrics()
+        torch._dynamo.reset()
+        try:
+            DDP._active_ddp_module = model
+            opt_model = torch.compile(model)
+            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
+            opt_model(inp)
+            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
+            opt_model(inp2)
+            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
+            # no recompiles
+            self.assertEqual(compile_count_before, compile_count_after)
+        finally:
+            DDP._active_ddp_module = None
+
     @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
     def test_unbacked_symbol_splitting_direct(self):
         class Model(nn.Module):
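The new test leans on two knobs that are worth seeing in isolation: `torch._dynamo.mark_dynamic` to pre-declare a dynamic dimension, and the compilation-metrics counters to assert that a second shape does not trigger a recompile. A minimal sketch of that pattern without the DDP harness (the function and shapes below are made up for illustration):

import torch


def f(x):
    return x * 2 + 1


inp = torch.randn(16, 18, 768)
inp2 = torch.randn(16, 20, 768)

# Declare dim 1 dynamic up front so the first compile is already symbolic there.
torch._dynamo.mark_dynamic(inp, 1)
torch._dynamo.mark_dynamic(inp2, 1)

torch._dynamo.utils.clear_compilation_metrics()
torch._dynamo.reset()

opt_f = torch.compile(f)
opt_f(inp)
before = len(torch._dynamo.utils.get_compilation_metrics())
opt_f(inp2)
after = len(torch._dynamo.utils.get_compilation_metrics())
assert before == after  # a new size on the dynamic dim should not recompile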
@@ -2150,16 +2150,18 @@ def set_tracing_context_output_strides(example_inputs, compiled_graph):
             if exprs is None:
                 context.output_strides.append(None)
             else:
-                context.output_strides.append(
-                    tuple(
-                        (
-                            shape_env.evaluate_symexpr(e)
-                            if shape_env is not None
-                            else int(e)
-                        )
-                        for e in exprs
-                    )
-                )
+                fakify_first_call = False
+                if ctx := torch._guards.TracingContext.try_get():
+                    fakify_first_call = ctx.fakify_first_call
+
+                def map_expr(e):
+                    if shape_env is None:
+                        return int(e)
+                    if fakify_first_call:
+                        return shape_env.deserialize_symexpr(e)
+                    return shape_env.evaluate_symexpr(e)
+
+                context.output_strides.append(tuple(map_expr(e) for e in exprs))
 
 
 def should_use_remote_fx_graph_cache():
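A toy mirror of the `map_expr` dispatch above, with a stub standing in for the ShapeEnv so it runs on its own (the stub, the expression string, and the hint values are invented for illustration; the real methods are `ShapeEnv.evaluate_symexpr` and the `deserialize_symexpr` added in this PR):

import sympy


class StubShapeEnv:
    def __init__(self, var_to_val):
        self.var_to_val = var_to_val  # e.g. {"s1": 18}

    def evaluate_symexpr(self, code):
        # Concrete path: substitute the current hints and return a plain int.
        return int(sympy.sympify(code).subs(self.var_to_val))

    def deserialize_symexpr(self, code):
        # Symbolic path: keep the free symbols in the expression.
        return sympy.sympify(code)


def map_expr(e, shape_env, fakify_first_call):
    if shape_env is None:
        return int(e)
    if fakify_first_call:
        return shape_env.deserialize_symexpr(e)
    return shape_env.evaluate_symexpr(e)


env = StubShapeEnv({"s1": 18})
print(map_expr("768*s1", env, fakify_first_call=False))  # 13824 (specialized)
print(map_expr("768*s1", env, fakify_first_call=True))   # 768*s1 (stays symbolic)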
@@ -5231,6 +5231,16 @@ class ShapeEnv:
         args = {str(e): val for e, val in self.var_to_val.items()}
         return eval(code, SYMPY_INTERP, args)
 
+    def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]:
+        """
+        To be used by compile_fx to deserialize symexprs
+        """
+        args = {
+            str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None))
+            for e, val in self.var_to_val.items()
+        }
+        return eval(code, SYMPY_INTERP, args)
+
     def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool:
         """
         Expected to be used with produce_guards_expression(). Evaluates an expression
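Why binding the variables to `SymInt`s makes the `eval` stay symbolic: `eval` only applies the operators written in the expression string, so whatever the operands' arithmetic returns is what comes back, and the new method binds each ShapeEnv variable to a `SymInt` wrapping a `SymNode` in this ShapeEnv (as shown in the diff above). A tiny stand-in with overloaded arithmetic (purely illustrative, not `SymInt` itself) makes the mechanism visible:

class Sym:
    """Toy symbolic value: arithmetic builds up an expression string."""

    def __init__(self, expr):
        self.expr = expr

    def __mul__(self, other):
        return Sym(f"({self.expr} * {other})")

    def __rmul__(self, other):
        return Sym(f"({other} * {self.expr})")

    def __repr__(self):
        return self.expr


# evaluate_symexpr-style: bind variables to plain ints -> a plain int comes out.
print(eval("768*s1", {}, {"s1": 18}))            # 13824
# deserialize_symexpr-style: bind variables to symbolic objects -> symbolic out.
print(eval("768*s1", {}, {"s1": Sym("s1")}))     # (768 * s1)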