From c5b79699e1aec4436745ba8c32b2d6e36951bff7 Mon Sep 17 00:00:00 2001
From: Yidi Wu
Date: Fri, 1 Nov 2024 14:54:06 -0700
Subject: [PATCH] [hop free symbols] replace ctx.save_for_backward to support symints/ints (#138737)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138737
Approved by: https://github.com/drisspg, https://github.com/zou3519, https://github.com/Chillee
ghstack dependencies: #138345, #138428, #138558
---
 test/inductor/test_flex_attention.py      | 49 ++++++++++++++++-
 torch/_higher_order_ops/flex_attention.py | 66 +++++++++++++----------
 torch/_higher_order_ops/utils.py          |  4 +-
 3 files changed, 90 insertions(+), 29 deletions(-)

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index e647558e70c..d0fc98b1ba6 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -3005,7 +3005,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
                 qkv = qkv.permute(2, 0, 3, 1, 4)
                 q, k, v = qkv
-                y = flex_attention(q, k, v, block_mask=block_mask)
+                y = flex_attention(
+                    q,
+                    k,
+                    v,
+                    block_mask=block_mask,
+                )
                 return y.transpose(1, 2).contiguous().view(B, T, C)
 
         model = SimpleAttention().cuda()
@@ -3033,6 +3038,48 @@
 
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
 
+    @supported_platform
+    def test_symbol_closure_in_score_mod(self):
+        class SimpleAttention(torch.nn.Module):
+            def __init__(self, dim=512, n_head=8):
+                super().__init__()
+                self.qkv = torch.nn.Linear(dim, 3 * dim)
+                self.n_head = n_head
+                self.head_dim = dim // n_head
+
+            def forward(self, x, block_mask=None):
+                B, T, C = x.size()
+                qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
+                qkv = qkv.permute(2, 0, 3, 1, 4)
+                q, k, v = qkv
+                return flex_attention(
+                    q,
+                    k,
+                    v,
+                    score_mod=lambda s, b, h, q, k: s + B,
+                    block_mask=block_mask,
+                )
+
+        model = SimpleAttention().cuda()
+        from torch._dynamo.testing import EagerAndRecordGraphs
+
+        backend = EagerAndRecordGraphs()
+        model.compile(mode="default", dynamic=True, backend=backend)
+        sequence_len = 256
+
+        torch._dynamo.reset()
+        for batch_shape in [4, 16, 32]:
+            x = torch.randn(batch_shape, sequence_len, 512).cuda()
+            model(x)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            backend.graphs[0].score_mod_0.code.strip(),
+            """\
+def forward(self, child_4 : torch.Tensor, child_5 : torch.Tensor, child_6 : torch.Tensor, child_7 : torch.Tensor, child_8 : torch.Tensor, getitem : torch.SymInt):
+    add = child_4 + getitem;  child_4 = getitem = None
+    return add""",
+        )
+
     @supported_platform
     def test_fw_bw_graph_correctness(self):
         cnt = CompileCounterWithBackend("aot_eager")
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 56794cc1b93..f38cb6a0de3 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -9,6 +9,8 @@ from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     autograd_not_implemented,
     reenter_make_fx,
+    save_tensors_and_symints_for_backward,
+    saved_tensors_and_symints,
     UnsupportedAliasMutationException,
 )
 from torch._ops import HigherOrderOperator
@@ -84,7 +86,7 @@ class FlexAttentionHOP(HigherOrderOperator):
         mask_mod_other_buffers: Tuple = (),
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if not all(
-            isinstance(buf, torch.Tensor)
+            isinstance(buf, (torch.Tensor, torch.SymInt, int))
             for buf in score_mod_other_buffers + mask_mod_other_buffers
         ):
             raise RuntimeError("Other buffers must be tensors.")
@@ -414,7 +416,7 @@ def flex_attention_functionalize(
     assert isinstance(score_mod_other_buffers_unwrapped, tuple)
     assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
     assert all(
-        isinstance(item, torch.Tensor)
+        isinstance(item, (torch.Tensor, torch.SymInt, int))
         for item in score_mod_other_buffers_unwrapped
         + mask_mod_other_buffers_unwrapped
     )
@@ -502,14 +504,18 @@ def create_fw_bw_graph(
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
 
-            def _from_fun(t: Tensor) -> Tensor:
-                return torch.empty_strided(
-                    t.size(),
-                    t.stride(),
-                    device=t.device,
-                    dtype=t.dtype,
-                    requires_grad=t.requires_grad,
-                )
+            def _from_fun(
+                t: Union[Tensor, torch.SymInt, int]
+            ) -> Union[Tensor, torch.SymInt, int]:
+                if isinstance(t, torch.Tensor):
+                    return torch.empty_strided(
+                        t.size(),
+                        t.stride(),
+                        device=t.device,
+                        dtype=t.dtype,
+                        requires_grad=t.requires_grad,
+                    )
+                return t
 
             # If someone runs this hop under the default compiler backend ("eager")
             # Then this path will be run with the actual user inputs. We convert them
@@ -524,8 +530,14 @@ def create_fw_bw_graph(
                 unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
                 unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
 
-            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
-            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)
+            assert all(
+                isinstance(t, (FakeTensor, torch.SymInt, int))
+                for t in unwrapped_score_mod_indexes
+            )
+            assert all(
+                isinstance(t, (FakeTensor, torch.SymInt, int))
+                for t in unwrapped_other_buffers
+            )
 
             example_flat_out = pytree.tree_map(
                 _from_fun,
@@ -591,9 +603,6 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         ctx._fw_graph = fw_graph
         ctx._joint_graph = joint_graph
         ctx._mask_graph = block_mask[-1]
-        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
-        ctx._Q_BLOCK_SIZE = block_mask[8]
-        ctx._KV_BLOCK_SIZE = block_mask[9]
         ctx.scale = scale
         ctx.kernel_options = kernel_options
         ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
@@ -610,21 +619,24 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
                 mask_mod_other_buffers,
             )
 
-        ctx.save_for_backward(
-            query,
-            key,
-            value,
-            out,
-            logsumexp,
-            *block_mask[:8],
-            *score_mod_other_buffers,
-            *mask_mod_other_buffers,
+        save_tensors_and_symints_for_backward(
+            ctx,
+            (
+                query,
+                key,
+                value,
+                out,
+                logsumexp,
+                *block_mask[:10],
+                *score_mod_other_buffers,
+                *mask_mod_other_buffers,
+            ),
         )
         return out, logsumexp
 
     @staticmethod
     def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
-        fw_args = ctx.saved_tensors
+        fw_args = saved_tensors_and_symints(ctx)
         (
             query,
             key,
@@ -639,13 +651,13 @@
             q_indices,
             full_q_num_blocks,
             full_q_indices,
+            Q_BLOCK_SIZE,
+            KV_BLOCK_SIZE,
             *other_buffers,
         ) = fw_args
         fw_graph = ctx._fw_graph
         joint_graph = ctx._joint_graph
         mask_graph = ctx._mask_graph
-        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
-        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
         scale = ctx.scale
         kernel_options = ctx.kernel_options
         score_mod_other_buffers = tuple(
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py
index 549e1af54f9..f6a8d29d520 100644
--- a/torch/_higher_order_ops/utils.py
+++ b/torch/_higher_order_ops/utils.py
@@ -431,7 +431,9 @@ def _stack_pytree(pytrees):
 # iterating over the pos list and pop one item from the front of paritioned_args[pos[i]].
 # We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
 def save_tensors_and_symints_for_backward(ctx, args):
-    assert all(isinstance(arg, (torch.Tensor, torch.SymInt, int)) for arg in args), args
+    assert all(
+        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
+    ), args
     partitioned_args: List[Any] = [[], []]
     pos = []
    for i, arg in enumerate(args):
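
The two helpers adopted above live in torch/_higher_order_ops/utils.py and stand in for ctx.save_for_backward whenever the saved values mix tensors with SymInts or plain ints, which is exactly the situation the backward pass hits with the block sizes in block_mask and the batch symbol captured by score_mod. A minimal sketch of the pattern follows; ScaleByLength and its inputs are made up for illustration, and only the two helper functions and their module path are taken from this patch:

import torch
from torch._higher_order_ops.utils import (
    save_tensors_and_symints_for_backward,
    saved_tensors_and_symints,
)


class ScaleByLength(torch.autograd.Function):
    # Hypothetical op, for illustration only; not part of this PR.
    @staticmethod
    def forward(ctx, x, length):
        # length may be a plain int or a torch.SymInt under dynamic shapes.
        # ctx.save_for_backward expects tensors, so the helper is used instead:
        # it partitions tensors from symints/ints and stores both on ctx.
        save_tensors_and_symints_for_backward(ctx, (x, length))
        return x * x * length

    @staticmethod
    def backward(ctx, grad_out):
        # Recovers the saved values in their original order, tensors and ints alike.
        x, length = saved_tensors_and_symints(ctx)
        # d/dx (x * x * length) = 2 * x * length; no gradient for the int input.
        return grad_out * 2 * x * length, None

Under torch.compile with dynamic=True, length would arrive as a SymInt, mirroring what the new test_symbol_closure_in_score_mod exercises for the captured batch size B.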