[hop free symbols] replace ctx.save_for_backward to support symints/ints (#138737)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138737
Approved by: https://github.com/drisspg, https://github.com/zou3519, https://github.com/Chillee
ghstack dependencies: #138345, #138428, #138558
Yidi Wu 2024-11-01 14:54:06 -07:00 committed by PyTorch MergeBot
parent ac20d0f893
commit c5b79699e1
3 changed files with 90 additions and 29 deletions
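
The change itself is mechanical: ctx.save_for_backward accepts only tensors, so FlexAttentionAutogradOp previously had to stash the integer block sizes on ctx by hand. This commit switches to the save_tensors_and_symints_for_backward / saved_tensors_and_symints helpers, which partition a mixed tuple by position, route the tensors through save_for_backward, and re-interleave everything on restore. Below is a minimal sketch of that pattern, not the real implementation; the names _save_mixed_for_backward and _saved_mixed are hypothetical, and the actual helpers live in torch._higher_order_ops.utils as shown in the diff further down.

    import torch

    def _save_mixed_for_backward(ctx, args):
        # Partition by position: tensors go through ctx.save_for_backward so they
        # get autograd's usual lifetime checks; SymInts/ints stay on ctx directly.
        tensors, others, pos = [], [], []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                pos.append(0)
                tensors.append(arg)
            else:
                pos.append(1)
                others.append(arg)
        ctx.save_for_backward(*tensors)
        ctx._saved_others = others
        ctx._saved_pos = pos

    def _saved_mixed(ctx):
        # Re-interleave saved tensors and non-tensors in the original argument order.
        tensors, others = iter(ctx.saved_tensors), iter(ctx._saved_others)
        return tuple(next(tensors) if p == 0 else next(others) for p in ctx._saved_pos)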

@@ -3005,7 +3005,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv
                y = flex_attention(q, k, v, block_mask=block_mask)
                y = flex_attention(
                    q,
                    k,
                    v,
                    block_mask=block_mask,
                )
                return y.transpose(1, 2).contiguous().view(B, T, C)

        model = SimpleAttention().cuda()
@@ -3033,6 +3038,48 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
        self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)

    @supported_platform
    def test_symbol_closure_in_score_mod(self):
        class SimpleAttention(torch.nn.Module):
            def __init__(self, dim=512, n_head=8):
                super().__init__()
                self.qkv = torch.nn.Linear(dim, 3 * dim)
                self.n_head = n_head
                self.head_dim = dim // n_head

            def forward(self, x, block_mask=None):
                B, T, C = x.size()
                qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv
                return flex_attention(
                    q,
                    k,
                    v,
                    score_mod=lambda s, b, h, q, k: s + B,
                    block_mask=block_mask,
                )

        model = SimpleAttention().cuda()
        from torch._dynamo.testing import EagerAndRecordGraphs

        backend = EagerAndRecordGraphs()
        model.compile(mode="default", dynamic=True, backend=backend)
        sequence_len = 256
        torch._dynamo.reset()
        for batch_shape in [4, 16, 32]:
            x = torch.randn(batch_shape, sequence_len, 512).cuda()
            model(x)
        self.assertEqual(len(backend.graphs), 1)
        self.assertExpectedInline(
            backend.graphs[0].score_mod_0.code.strip(),
            """\
def forward(self, child_4 : torch.Tensor, child_5 : torch.Tensor, child_6 : torch.Tensor, child_7 : torch.Tensor, child_8 : torch.Tensor, getitem : torch.SymInt):
    add = child_4 + getitem; child_4 = getitem = None
    return add""",
        )

    @supported_platform
    def test_fw_bw_graph_correctness(self):
        cnt = CompileCounterWithBackend("aot_eager")

@@ -9,6 +9,8 @@ from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    save_tensors_and_symints_for_backward,
    saved_tensors_and_symints,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
@@ -84,7 +86,7 @@ class FlexAttentionHOP(HigherOrderOperator):
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            isinstance(buf, (torch.Tensor, torch.SymInt, int))
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
@@ -414,7 +416,7 @@ def flex_attention_functionalize(
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        isinstance(item, (torch.Tensor, torch.SymInt, int))
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )
@@ -502,14 +504,18 @@ def create_fw_bw_graph(
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t: Tensor) -> Tensor:
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                )

            def _from_fun(
                t: Union[Tensor, torch.SymInt, int]
            ) -> Union[Tensor, torch.SymInt, int]:
                if isinstance(t, torch.Tensor):
                    return torch.empty_strided(
                        t.size(),
                        t.stride(),
                        device=t.device,
                        dtype=t.dtype,
                        requires_grad=t.requires_grad,
                    )
                return t

            # If someone runs this hop under the default compiler backend ("eager")
            # Then this path will be run with the actual user inputs. We convert them
@@ -524,8 +530,14 @@ def create_fw_bw_graph(
            unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
            unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)
            assert all(
                isinstance(t, (FakeTensor, torch.SymInt, int))
                for t in unwrapped_score_mod_indexes
            )
            assert all(
                isinstance(t, (FakeTensor, torch.SymInt, int))
                for t in unwrapped_other_buffers
            )

            example_flat_out = pytree.tree_map(
                _from_fun,
@@ -591,9 +603,6 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
        ctx._fw_graph = fw_graph
        ctx._joint_graph = joint_graph
        ctx._mask_graph = block_mask[-1]
        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
        ctx._Q_BLOCK_SIZE = block_mask[8]
        ctx._KV_BLOCK_SIZE = block_mask[9]
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
@@ -610,21 +619,24 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
            mask_mod_other_buffers,
        )

        ctx.save_for_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            *block_mask[:8],
            *score_mod_other_buffers,
            *mask_mod_other_buffers,
        save_tensors_and_symints_for_backward(
            ctx,
            (
                query,
                key,
                value,
                out,
                logsumexp,
                *block_mask[:10],
                *score_mod_other_buffers,
                *mask_mod_other_buffers,
            ),
        )
        return out, logsumexp

    @staticmethod
    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
        fw_args = ctx.saved_tensors
        fw_args = saved_tensors_and_symints(ctx)
        (
            query,
            key,
@@ -639,13 +651,13 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
            q_indices,
            full_q_num_blocks,
            full_q_indices,
            Q_BLOCK_SIZE,
            KV_BLOCK_SIZE,
            *other_buffers,
        ) = fw_args
        fw_graph = ctx._fw_graph
        joint_graph = ctx._joint_graph
        mask_graph = ctx._mask_graph
        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
        scale = ctx.scale
        kernel_options = ctx.kernel_options
        score_mod_other_buffers = tuple(

@@ -431,7 +431,9 @@ def _stack_pytree(pytrees):
# iterating over the pos list and popping one item from the front of partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
def save_tensors_and_symints_for_backward(ctx, args):
    assert all(isinstance(arg, (torch.Tensor, torch.SymInt, int)) for arg in args), args
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    partitioned_args: List[Any] = [[], []]
    pos = []
    for i, arg in enumerate(args):
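
For reference, a hedged usage sketch of the helper pair added to the imports above, analogous to how FlexAttentionAutogradOp now uses them. The Scale function and its arguments are made up for illustration; the helpers are private PyTorch internals, and their behavior here (accepting Tensor, SymInt, int, and now None entries) is taken from this diff.

    import torch
    from torch._higher_order_ops.utils import (
        save_tensors_and_symints_for_backward,
        saved_tensors_and_symints,
    )

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            # scale may be a plain int, or a SymInt under dynamic shapes; the
            # mixed tuple is saved in one call instead of splitting it by hand.
            save_tensors_and_symints_for_backward(ctx, (x, scale))
            return x * scale

        @staticmethod
        def backward(ctx, grad_out):
            x, scale = saved_tensors_and_symints(ctx)
            return grad_out * scale, None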