[Traceable FSDP2] Add Dynamo support for run_with_rng_state HOP (#127247)

Test command:
`pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`
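For context, the HOP is exercised when a randomness-carrying op (here SDPA) is activation-checkpointed inside a `torch.compile` region and its backward is traced by compiled autograd. A minimal sketch of that setup, condensed from the test added below (shapes and the compiler wrapper are illustrative, not part of this PR):

```python
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch._dynamo import compiled_autograd


def fn(xq, xk):
    # Checkpointed SDPA: the partitioner saves the RNG state in the forward graph
    # and replays it via run_with_rng_state in the backward graph.
    return torch.utils.checkpoint.checkpoint(
        lambda q, k: F.scaled_dot_product_attention(q, k, k, is_causal=True),
        xq, xk, use_reentrant=False,
    )


xq = torch.ones(1, 1, 2, 2, requires_grad=True)
xk = torch.ones(1, 1, 4, 2, requires_grad=True)
out = torch.compile(fn, fullgraph=True)(xq, xk)
with compiled_autograd.enable(lambda gm: torch.compile(gm, fullgraph=True)):
    out.sum().backward()
```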

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127247
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129414
Will Feng 2024-06-24 16:33:18 -07:00 committed by PyTorch MergeBot
parent b24787b757
commit aa4ee2cb9e
3 changed files with 128 additions and 0 deletions

test/inductor/test_compiled_autograd.py

@@ -11,6 +11,7 @@ from unittest import mock
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd, config
from torch._dynamo.utils import counters
@@ -1308,6 +1309,99 @@ main()
        self.check_output_and_recompiles(fn, 1)

    def test_trace_run_with_rng_state(self):
        def sdpa(xq, xk):
            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)

        def g(xq_1, xk_1, xq_2, xk_2):
            # xq: (bs, n_local_heads, seqlen, head_dim)
            # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
            y1 = sdpa(xq_1, xk_1)
            y2 = torch.utils.checkpoint.checkpoint(
                sdpa, xq_2, xk_2, use_reentrant=False
            )
            y = torch.mul(y1, y2)
            z = torch.matmul(y, y)
            return z

        def f():
            bs = 1
            n_local_heads = 1
            seqlen = 2
            head_dim = 2
            cache_len = 2
            xq_list = [
                torch.ones(
                    (bs, n_local_heads, seqlen, head_dim),
                    requires_grad=True,
                    device="cpu",
                )
                for _ in range(2)
            ]
            xk_list = [
                torch.ones(
                    (bs, n_local_heads, cache_len + seqlen, head_dim),
                    requires_grad=True,
                    device="cpu",
                )
                for _ in range(2)
            ]
            out = torch.compile(g, fullgraph=True)(
                xq_list[0], xk_list[0], xq_list[1], xk_list[1]
            )
            out.sum().backward()
            return out, *[x.grad for x in xq_list + xk_list]
"""
Walkthrough of what happens with `run_with_rng_state`:
1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
2. The Dynamo graph captured by Compiled Autograd looks like:
```
===== __compiled_fn_3 =====
torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
...
run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
getitem_8,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
getitem_3, getitem_4, getitem_4, 0.0, True,
)
...
```
3. We want to preserve this `run_with_rng_state` op when going through AOTAutograd. We do it by having special handling
in `run_with_rng_state` op's py_functionalize_impl.
"""
        def _run_with_rng_state_op_check(inductor_post_grad_graph):
            # Checks that the `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph.
            op_set = {node.target for node in inductor_post_grad_graph.nodes}
            if torch.ops.higher_order.run_and_save_rng_state not in op_set:
                # This is the backward graph, so check for the existence of the `run_with_rng_state` op.
                self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)

        with torch._inductor.config.patch(
            post_grad_custom_post_pass=_run_with_rng_state_op_check
        ):
            compiler_fn = make_compiler_fn(fullgraph=True)

            def make_compiler_fn_with_op_check():
                def _compiler_fn(gm):
                    # Checks that the `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph.
                    self.assertTrue(
                        any(
                            node.target is torch.ops.higher_order.run_with_rng_state
                            for node in gm.graph.nodes
                        )
                    )
                    return compiler_fn(gm)

                return _compiler_fn

            compiler_fn_with_op_check = make_compiler_fn_with_op_check()
            self.check_output_and_recompiles(
                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
            )
    def test_autograd_cpp_node(self):
        cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {

torch/_dynamo/variables/higher_order_ops.py

@@ -554,6 +554,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
        elif value.__name__ == "strict_mode":
            return StrictModeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "run_with_rng_state":
            return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "associative_scan":
            return AssociativeScanHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "call_torchbind":
@@ -1440,6 +1442,26 @@ class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )


class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
    # Lifts the `run_with_rng_state` HOP into the Dynamo FX graph as an opaque
    # call_function node, with its arguments converted to proxies.
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        p_args = tuple(arg.as_proxy() for arg in args)
        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=p_args,
                kwargs=p_kwargs,
            ),
            example_value=None,
        )


class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
    """
    Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace

torch/_prims/rng_prims.py

@@ -247,6 +247,18 @@ def register_run_with_rng_state_op():
        with mode:
            return op(*args, **kwargs)

    @run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, rng_state, op, *args, **kwargs):
        # Unwrap functional tensors, re-dispatch the HOP below the functionalization
        # key, then wrap the result back so the op is preserved in the traced graph.
        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = run_with_rng_state(
                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
            )

        return ctx.wrap_tensors(out)

    return run_with_rng_state