[Traceable FSDP2] Add Dynamo support for run_with_rng_state HOP (#127247)
Test command: `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127247
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129414
parent: b24787b757
commit: aa4ee2cb9e
3 changed files with 128 additions and 0 deletions
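For context on the op being traced: `run_with_rng_state` is the backward-graph twin of `run_and_save_rng_state`. The forward graph runs a randomized op and saves the RNG state it consumed; the backward graph (where the partitioner inserts `run_with_rng_state`, per the test walkthrough below) replays the op under that saved state so recomputation sees identical randomness. Below is a minimal sketch of the replay semantics, assuming CPU RNG only; `run_with_rng_state_sketch` is a hypothetical stand-in, not PyTorch's implementation, which also handles device generators.

```python
import torch

# Hypothetical sketch of the op's eager semantics: rewind the generator to a
# saved state, run the op, then restore the state so nothing else is perturbed.
def run_with_rng_state_sketch(rng_state, op, *args, **kwargs):
    current_state = torch.random.get_rng_state()  # remember where the RNG is now
    torch.random.set_rng_state(rng_state)         # rewind to the saved forward-time state
    out = op(*args, **kwargs)                     # replay the randomized op deterministically
    torch.random.set_rng_state(current_state)     # leave the ambient RNG untouched
    return out

state = torch.random.get_rng_state()
a = torch.rand(3)
b = run_with_rng_state_sketch(state, torch.rand, 3)  # replays the same draw
assert torch.equal(a, b)
```

Restoring the ambient state afterwards keeps the replay from shifting the random stream seen by unrelated ops.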
test/inductor/test_compiled_autograd.py:

```diff
@@ -11,6 +11,7 @@ from unittest import mock
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import _inductor as inductor
 from torch._dynamo import compiled_autograd, config
 from torch._dynamo.utils import counters
```
```diff
@@ -1308,6 +1309,99 @@ main()
         self.check_output_and_recompiles(fn, 1)
 
+    def test_trace_run_with_rng_state(self):
+        def sdpa(xq, xk):
+            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
+
+        def g(xq_1, xk_1, xq_2, xk_2):
+            # xq: (bs, n_local_heads, seqlen, head_dim)
+            # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
+            y1 = sdpa(xq_1, xk_1)
+            y2 = torch.utils.checkpoint.checkpoint(
+                sdpa, xq_2, xk_2, use_reentrant=False
+            )
+            y = torch.mul(y1, y2)
+            z = torch.matmul(y, y)
+            return z
+
+        def f():
+            bs = 1
+            n_local_heads = 1
+            seqlen = 2
+            head_dim = 2
+            cache_len = 2
+            xq_list = [
+                torch.ones(
+                    (bs, n_local_heads, seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            xk_list = [
+                torch.ones(
+                    (bs, n_local_heads, cache_len + seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            out = torch.compile(g, fullgraph=True)(
+                xq_list[0], xk_list[0], xq_list[1], xk_list[1]
+            )
+            out.sum().backward()
+            return out, *[x.grad for x in xq_list + xk_list]
+
+        """
+        Walkthrough of what happens with `run_with_rng_state`:
+        1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
+        2. The Dynamo graph captured by Compiled Autograd looks like:
+        ```
+        ===== __compiled_fn_3 =====
+        torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
+            def forward(self, L_inputs_ : list):
+                ...
+                run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
+                    getitem_8,
+                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+                    getitem_3, getitem_4, getitem_4, 0.0, True,
+                )
+                ...
+        ```
+        3. We want to preserve this `run_with_rng_state` op when going through AOTAutograd.
+        We do so via special handling in the `run_with_rng_state` op's py_functionalize_impl.
+        """
+
+        def _run_with_rng_state_op_check(inductor_post_grad_graph):
+            # Check that the `run_with_rng_state` op exists in Compiled Autograd's
+            # Inductor post-grad graph.
+            op_set = {node.target for node in inductor_post_grad_graph.nodes}
+            if torch.ops.higher_order.run_and_save_rng_state not in op_set:
+                # This is the backward graph, so check for the `run_with_rng_state` op.
+                self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)
+
+        with torch._inductor.config.patch(
+            post_grad_custom_post_pass=_run_with_rng_state_op_check
+        ):
+            compiler_fn = make_compiler_fn(fullgraph=True)
+
+            def make_compiler_fn_with_op_check():
+                def _compiler_fn(gm):
+                    # Check that the `run_with_rng_state` op exists in Compiled
+                    # Autograd's Dynamo graph.
+                    self.assertTrue(
+                        any(
+                            node.target is torch.ops.higher_order.run_with_rng_state
+                            for node in gm.graph.nodes
+                        )
+                    )
+                    return compiler_fn(gm)
+
+                return _compiler_fn
+
+            compiler_fn_with_op_check = make_compiler_fn_with_op_check()
+            self.check_output_and_recompiles(
+                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
+            )
+
     def test_autograd_cpp_node(self):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
```
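The test composes `torch.compile` with non-reentrant activation checkpointing precisely because the recompute path is where RNG replay matters: the checkpointed `sdpa` runs again during backward and must see the same randomness as in forward. The same requirement is visible in plain eager checkpointing; here is a minimal sketch with dropout standing in for the randomized op (a hypothetical illustration, not code from the PR):

```python
import torch
from torch.utils.checkpoint import checkpoint

def noisy(x):
    # Randomized op inside the checkpointed region: it runs once in forward and
    # is re-run during the backward recompute, which must draw the same mask.
    return torch.nn.functional.dropout(x, p=0.5)

x = torch.randn(8, requires_grad=True)
y = checkpoint(noisy, x, use_reentrant=False)
y.sum().backward()  # recompute replays dropout against the preserved RNG state
```

`use_reentrant=False` matches the test; eager `torch.utils.checkpoint` handles the replay itself by saving and restoring RNG state around the recompute (`preserve_rng_state=True` by default), while the compiled path expresses the same contract through the `run_and_save_rng_state` / `run_with_rng_state` op pair.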
torch/_dynamo/variables/torch.py:

```diff
@@ -554,6 +554,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
             return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
         elif value.__name__ == "strict_mode":
             return StrictModeHigherOrderVariable(value, source, **kwargs)
+        elif value.__name__ == "run_with_rng_state":
+            return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "associative_scan":
             return AssociativeScanHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "call_torchbind":
```
```diff
@@ -1440,6 +1442,26 @@ class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
         )
 
 
+class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def call_function(
+        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+    ) -> "VariableTracker":
+        from .builder import wrap_fx_proxy
+
+        p_args = tuple(arg.as_proxy() for arg in args)
+        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
+        return wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                self.value,
+                args=p_args,
+                kwargs=p_kwargs,
+            ),
+            example_value=None,
+        )
+
+
 class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
     """
     Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
```
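The new variable is deliberately thin: no speculation into the callable, just a pass-through `call_function` node whose target is the HOP itself, with every argument lowered via `as_proxy()`. A hand-built FX graph with the same node shape is sketched below (illustrative only; the `torch._prims.rng_prims` import to force HOP registration is an assumption about where the op is defined):

```python
import torch
import torch._prims.rng_prims  # noqa: F401  (assumed to register the HOP on import)
from torch.fx import Graph

# Build a graph containing the node shape the variable emits.
g = Graph()
rng_state = g.placeholder("rng_state")
x = g.placeholder("x")
out = g.call_function(
    torch.ops.higher_order.run_with_rng_state,
    args=(rng_state, torch.ops.aten.dropout.default, x, 0.5, True),
)
g.output(out)

# The same membership check the test performs on the real graph:
assert any(
    node.op == "call_function"
    and node.target is torch.ops.higher_order.run_with_rng_state
    for node in g.nodes
)
```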
torch/_prims/rng_prims.py:

```diff
@@ -247,6 +247,18 @@ def register_run_with_rng_state_op():
         with mode:
             return op(*args, **kwargs)
 
+    @run_with_rng_state.py_functionalize_impl
+    def impl_functional(ctx, rng_state, op, *args, **kwargs):
+        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
+        unwrapped_args = ctx.unwrap_tensors(args)
+        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
+
+        with ctx.redispatch_to_next():
+            out = run_with_rng_state(
+                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
+            )
+            return ctx.wrap_tensors(out)
+
     return run_with_rng_state
```
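This hook follows the standard unwrap → redispatch → re-wrap functionalization shape: tensor inputs lose their functional wrappers, the HOP is re-invoked below the Functionalize dispatch key (so it reappears intact in the functionalized graph rather than being decomposed away), and the outputs are re-wrapped for the caller. The same skeleton applied to a hypothetical HOP, as a registration-only sketch (the other dispatch-key implementations that would make `my_hop` actually callable are elided):

```python
from torch._ops import HigherOrderOperator

class MyHop(HigherOrderOperator):
    """Hypothetical HOP, used only to illustrate the registration pattern."""

    def __init__(self):
        super().__init__("my_hop")

my_hop = MyHop()

@my_hop.py_functionalize_impl
def my_hop_functional(ctx, *args, **kwargs):
    # Strip FunctionalTensor wrappers from every tensor input.
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    # Re-dispatch below the Functionalize key so the op is recorded into the
    # traced graph as-is instead of being decomposed.
    with ctx.redispatch_to_next():
        out = my_hop(*unwrapped_args, **unwrapped_kwargs)
    # Re-wrap the outputs for the functionalized caller.
    return ctx.wrap_tensors(out)
```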