From aa4ee2cb9e1f9be6bbdd27654e0f768b7fe9be6c Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 24 Jun 2024 16:33:18 -0700
Subject: [PATCH] [Traceable FSDP2] Add Dynamo support for run_with_rng_state
 HOP (#127247)

Test command: `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127247
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129414
---
 test/inductor/test_compiled_autograd.py     | 94 +++++++++++++++++++++
 torch/_dynamo/variables/higher_order_ops.py | 22 +++++
 torch/_prims/rng_prims.py                   | 12 +++
 3 files changed, 128 insertions(+)

diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index fe65dca77f6..f218a203779 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -11,6 +11,7 @@ from unittest import mock
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import _inductor as inductor
 from torch._dynamo import compiled_autograd, config
 from torch._dynamo.utils import counters
@@ -1308,6 +1309,99 @@ main()
 
         self.check_output_and_recompiles(fn, 1)
 
+    def test_trace_run_with_rng_state(self):
+        def sdpa(xq, xk):
+            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
+
+        def g(xq_1, xk_1, xq_2, xk_2):
+            # xq: (bs, n_local_heads, seqlen, head_dim)
+            # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
+            y1 = sdpa(xq_1, xk_1)
+            y2 = torch.utils.checkpoint.checkpoint(
+                sdpa, xq_2, xk_2, use_reentrant=False
+            )
+            y = torch.mul(y1, y2)
+            z = torch.matmul(y, y)
+            return z
+
+        def f():
+            bs = 1
+            n_local_heads = 1
+            seqlen = 2
+            head_dim = 2
+            cache_len = 2
+            xq_list = [
+                torch.ones(
+                    (bs, n_local_heads, seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            xk_list = [
+                torch.ones(
+                    (bs, n_local_heads, cache_len + seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            out = torch.compile(g, fullgraph=True)(
+                xq_list[0], xk_list[0], xq_list[1], xk_list[1]
+            )
+            out.sum().backward()
+            return out, *[x.grad for x in xq_list + xk_list]
+
+        """
+        Walkthrough of what happens with `run_with_rng_state`:
+        1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
+        2. The Dynamo graph captured by Compiled Autograd looks like:
+        ```
+        ===== __compiled_fn_3 =====
+        torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
+            def forward(self, L_inputs_ : list):
+                ...
+                run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
+                    getitem_8,
+                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+                    getitem_3, getitem_4, getitem_4, 0.0, True,
+                )
+                ...
+        ```
+        3. We want to preserve this `run_with_rng_state` op when going through AOTAutograd;
+        we do so via special handling in the op's py_functionalize_impl.
+        """
+
+        def _run_with_rng_state_op_check(inductor_post_grad_graph):
+            # Checks that the `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph.
+            op_set = {node.target for node in inductor_post_grad_graph.nodes}
+            if torch.ops.higher_order.run_and_save_rng_state not in op_set:
+                # This is the backward graph, so check for the existence of the `run_with_rng_state` op.
+                self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)
+
+        with torch._inductor.config.patch(
+            post_grad_custom_post_pass=_run_with_rng_state_op_check
+        ):
+            compiler_fn = make_compiler_fn(fullgraph=True)
+
+            def make_compiler_fn_with_op_check():
+                def _compiler_fn(gm):
+                    # Checks that the `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph.
+                    self.assertTrue(
+                        any(
+                            node.target is torch.ops.higher_order.run_with_rng_state
+                            for node in gm.graph.nodes
+                        )
+                    )
+                    return compiler_fn(gm)
+
+                return _compiler_fn
+
+            compiler_fn_with_op_check = make_compiler_fn_with_op_check()
+            self.check_output_and_recompiles(
+                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
+            )
+
     def test_autograd_cpp_node(self):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 59f8c26ce62..f874ca21e0d 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -554,6 +554,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
             return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
         elif value.__name__ == "strict_mode":
             return StrictModeHigherOrderVariable(value, source, **kwargs)
+        elif value.__name__ == "run_with_rng_state":
+            return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "associative_scan":
             return AssociativeScanHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "call_torchbind":
@@ -1440,6 +1442,26 @@ class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
         )
 
 
+class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def call_function(
+        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+    ) -> "VariableTracker":
+        from .builder import wrap_fx_proxy
+
+        p_args = tuple(arg.as_proxy() for arg in args)
+        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
+        return wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                self.value,
+                args=p_args,
+                kwargs=p_kwargs,
+            ),
+            example_value=None,
+        )
+
+
 class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
     """
     Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 1345ff0334f..f1d2fc6f426 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -247,6 +247,18 @@ def register_run_with_rng_state_op():
         with mode:
             return op(*args, **kwargs)
 
+    @run_with_rng_state.py_functionalize_impl
+    def impl_functional(ctx, rng_state, op, *args, **kwargs):
+        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
+        unwrapped_args = ctx.unwrap_tensors(args)
+        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
+
+        with ctx.redispatch_to_next():
+            out = run_with_rng_state(
+                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
+            )
+            return ctx.wrap_tensors(out)
+
     return run_with_rng_state
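
Note (reviewer sketch, appended for illustration; not part of the applied diff):
the contract the `run_with_rng_state` HOP must preserve is "replay an op under
a saved RNG state without disturbing the ambient stream". Below is a minimal
eager-mode sketch of that contract using only public `torch` RNG-state APIs;
`run_with_rng_state_sketch` is a hypothetical helper name and the dropout
example is illustrative, assuming CPU generator state (the real HOP dispatches
per-device).

```python
import torch

# Hedged sketch of the replay semantics: rewind to a saved RNG state, run the
# random op, then restore the ambient stream.
def run_with_rng_state_sketch(rng_state, op, *args, **kwargs):
    current_state = torch.get_rng_state()   # stash the ambient RNG stream
    torch.set_rng_state(rng_state)          # rewind to the state saved in forward
    try:
        return op(*args, **kwargs)          # replay the random op deterministically
    finally:
        torch.set_rng_state(current_state)  # restore the ambient stream

# Forward saves the state (what `run_and_save_rng_state` captures); the
# backward recompute replays under it and reproduces the same dropout mask.
x = torch.randn(4, 4)
saved = torch.get_rng_state()
out_fwd = torch.dropout(x, 0.5, True)
out_replay = run_with_rng_state_sketch(saved, torch.dropout, x, 0.5, True)
assert torch.equal(out_fwd, out_replay)
```

This is why the PR keeps the op intact end to end: `RunWithRNGStateHigherOrderVariable`
lets Dynamo re-emit the HOP as a `call_function` proxy, and the new
`py_functionalize_impl` unwraps the tensors, redispatches, and rewraps, so the
partitioner-inserted replay still holds during checkpointed recompute instead
of being decomposed away.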