mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[hop free symbols] replace ctx.save_for_backward to support symints/ints (#138737)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138737 Approved by: https://github.com/drisspg, https://github.com/zou3519, https://github.com/Chillee ghstack dependencies: #138345, #138428, #138558
This commit is contained in:
parent
ac20d0f893
commit
c5b79699e1
3 changed files with 90 additions and 29 deletions
|
|
@ -3005,7 +3005,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv
|
||||
y = flex_attention(q, k, v, block_mask=block_mask)
|
||||
y = flex_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_mask=block_mask,
|
||||
)
|
||||
return y.transpose(1, 2).contiguous().view(B, T, C)
|
||||
|
||||
model = SimpleAttention().cuda()
|
||||
|
|
@ -3033,6 +3038,48 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
|
||||
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
|
||||
|
||||
@supported_platform
|
||||
def test_symbol_closure_in_score_mod(self):
|
||||
class SimpleAttention(torch.nn.Module):
|
||||
def __init__(self, dim=512, n_head=8):
|
||||
super().__init__()
|
||||
self.qkv = torch.nn.Linear(dim, 3 * dim)
|
||||
self.n_head = n_head
|
||||
self.head_dim = dim // n_head
|
||||
|
||||
def forward(self, x, block_mask=None):
|
||||
B, T, C = x.size()
|
||||
qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv
|
||||
return flex_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
score_mod=lambda s, b, h, q, k: s + B,
|
||||
block_mask=block_mask,
|
||||
)
|
||||
|
||||
model = SimpleAttention().cuda()
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
||||
backend = EagerAndRecordGraphs()
|
||||
model.compile(mode="default", dynamic=True, backend=backend)
|
||||
sequence_len = 256
|
||||
|
||||
torch._dynamo.reset()
|
||||
for batch_shape in [4, 16, 32]:
|
||||
x = torch.randn(batch_shape, sequence_len, 512).cuda()
|
||||
model(x)
|
||||
self.assertEqual(len(backend.graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
backend.graphs[0].score_mod_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, child_4 : torch.Tensor, child_5 : torch.Tensor, child_6 : torch.Tensor, child_7 : torch.Tensor, child_8 : torch.Tensor, getitem : torch.SymInt):
|
||||
add = child_4 + getitem; child_4 = getitem = None
|
||||
return add""",
|
||||
)
|
||||
|
||||
@supported_platform
|
||||
def test_fw_bw_graph_correctness(self):
|
||||
cnt = CompileCounterWithBackend("aot_eager")
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from torch._higher_order_ops.utils import (
|
|||
_has_potential_branch_input_mutation,
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
|
@ -84,7 +86,7 @@ class FlexAttentionHOP(HigherOrderOperator):
|
|||
mask_mod_other_buffers: Tuple = (),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not all(
|
||||
isinstance(buf, torch.Tensor)
|
||||
isinstance(buf, (torch.Tensor, torch.SymInt, int))
|
||||
for buf in score_mod_other_buffers + mask_mod_other_buffers
|
||||
):
|
||||
raise RuntimeError("Other buffers must be tensors.")
|
||||
|
|
@ -414,7 +416,7 @@ def flex_attention_functionalize(
|
|||
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
|
||||
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
|
||||
assert all(
|
||||
isinstance(item, torch.Tensor)
|
||||
isinstance(item, (torch.Tensor, torch.SymInt, int))
|
||||
for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
|
||||
)
|
||||
|
||||
|
|
@ -502,14 +504,18 @@ def create_fw_bw_graph(
|
|||
with suspend_functionalization(), disable_functional_mode():
|
||||
with disable_proxy_modes_tracing():
|
||||
|
||||
def _from_fun(t: Tensor) -> Tensor:
|
||||
return torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
device=t.device,
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
)
|
||||
def _from_fun(
|
||||
t: Union[Tensor, torch.SymInt, int]
|
||||
) -> Union[Tensor, torch.SymInt, int]:
|
||||
if isinstance(t, torch.Tensor):
|
||||
return torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
device=t.device,
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
)
|
||||
return t
|
||||
|
||||
# If someone runs this hop under the default compiler backend ("eager")
|
||||
# Then this path will be run with the actual user inputs. We convert them
|
||||
|
|
@ -524,8 +530,14 @@ def create_fw_bw_graph(
|
|||
unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
|
||||
unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
|
||||
|
||||
assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
|
||||
assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)
|
||||
assert all(
|
||||
isinstance(t, (FakeTensor, torch.SymInt, int))
|
||||
for t in unwrapped_score_mod_indexes
|
||||
)
|
||||
assert all(
|
||||
isinstance(t, (FakeTensor, torch.SymInt, int))
|
||||
for t in unwrapped_other_buffers
|
||||
)
|
||||
|
||||
example_flat_out = pytree.tree_map(
|
||||
_from_fun,
|
||||
|
|
@ -591,9 +603,6 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
ctx._fw_graph = fw_graph
|
||||
ctx._joint_graph = joint_graph
|
||||
ctx._mask_graph = block_mask[-1]
|
||||
# KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
|
||||
ctx._Q_BLOCK_SIZE = block_mask[8]
|
||||
ctx._KV_BLOCK_SIZE = block_mask[9]
|
||||
ctx.scale = scale
|
||||
ctx.kernel_options = kernel_options
|
||||
ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
|
||||
|
|
@ -610,21 +619,24 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
mask_mod_other_buffers,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
*block_mask[:8],
|
||||
*score_mod_other_buffers,
|
||||
*mask_mod_other_buffers,
|
||||
save_tensors_and_symints_for_backward(
|
||||
ctx,
|
||||
(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
*block_mask[:10],
|
||||
*score_mod_other_buffers,
|
||||
*mask_mod_other_buffers,
|
||||
),
|
||||
)
|
||||
return out, logsumexp
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]: # type: ignore[override]
|
||||
fw_args = ctx.saved_tensors
|
||||
fw_args = saved_tensors_and_symints(ctx)
|
||||
(
|
||||
query,
|
||||
key,
|
||||
|
|
@ -639,13 +651,13 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
q_indices,
|
||||
full_q_num_blocks,
|
||||
full_q_indices,
|
||||
Q_BLOCK_SIZE,
|
||||
KV_BLOCK_SIZE,
|
||||
*other_buffers,
|
||||
) = fw_args
|
||||
fw_graph = ctx._fw_graph
|
||||
joint_graph = ctx._joint_graph
|
||||
mask_graph = ctx._mask_graph
|
||||
KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
|
||||
Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
|
||||
scale = ctx.scale
|
||||
kernel_options = ctx.kernel_options
|
||||
score_mod_other_buffers = tuple(
|
||||
|
|
|
|||
|
|
@ -431,7 +431,9 @@ def _stack_pytree(pytrees):
|
|||
# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]].
|
||||
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
|
||||
def save_tensors_and_symints_for_backward(ctx, args):
|
||||
assert all(isinstance(arg, (torch.Tensor, torch.SymInt, int)) for arg in args), args
|
||||
assert all(
|
||||
isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
|
||||
), args
|
||||
partitioned_args: List[Any] = [[], []]
|
||||
pos = []
|
||||
for i, arg in enumerate(args):
|
||||
|
|
|
|||
Loading…
Reference in a new issue