mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Check meta strides for expanded dims in effn_attn_bias (#146054)
With the `_scaled_dot_product_efficient_attention.default`, we have lowering logic to realize the bias to specific alignment constraints. Some of the dims can be expanded, and we need to keep the stride of that dim to 0 to avoid materializing a larger tensor than we need. Previously, we had checked stride of tensor, but if it is not realized, that will not work. so we should check the strides of the meta as well. Note: getting the exact of realizing/slicing/requiring_exact_strides was a little tricky. I commented to @exclamaforte on an example unable-to-fuse message you get if you do it incorrectly. Fix for https://github.com/pytorch/pytorch/issues/145760 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146054 Approved by: https://github.com/shunting314
This commit is contained in:
parent
71e8a2bda4
commit
002accfb8d
5 changed files with 65 additions and 8 deletions
|
|
@ -151,6 +151,37 @@ class CudaReproTests(TestCase):
|
|||
# dont check rng state
|
||||
self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2])
|
||||
|
||||
def test_effn_attn_bias_padding_misaligned(self):
|
||||
seqlen_start = 1008
|
||||
|
||||
for offset in range(-1, 2):
|
||||
seqlen = seqlen_start + offset
|
||||
torch._dynamo.reset()
|
||||
|
||||
bsz = 32
|
||||
q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
|
||||
k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
|
||||
v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
|
||||
mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda")
|
||||
inputs = [q, k, v, mask]
|
||||
|
||||
def f(q, k, v, mask):
|
||||
return F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0
|
||||
)
|
||||
|
||||
f_compiled = torch.compile(f)
|
||||
|
||||
out, code = run_and_get_code(f_compiled, *inputs)
|
||||
# padded bias should have an expanded dim
|
||||
FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
|
||||
# single fused padded kernel
|
||||
FileCheck().check("def call").check_count(
|
||||
"empty_strided_cuda", 1, exactly=True
|
||||
).check("return").run(code[0])
|
||||
|
||||
self.assertEqual(out, f(*inputs))
|
||||
|
||||
@skipIfRocm
|
||||
def test_input_channels_last(self):
|
||||
m = torch.nn.Sequential(
|
||||
|
|
|
|||
|
|
@ -5145,7 +5145,8 @@ class ExternKernel(InputsKernel):
|
|||
allow_padding=False,
|
||||
):
|
||||
assert order is not None or exact_strides is not None
|
||||
if x.get_numel() in (0, 1): # Layout doesn't matter
|
||||
# Layout generally doesn't matter, but some consuming external ops might have requirements
|
||||
if x.get_numel() in (0, 1) and not exact_strides:
|
||||
return x
|
||||
|
||||
# require x to have the layout
|
||||
|
|
|
|||
|
|
@ -2490,12 +2490,15 @@ def sdpa_constraint(fx_node, *args, **kwargs):
|
|||
out_size = list(arg.get_size())
|
||||
|
||||
expanded_dims = []
|
||||
if arg.maybe_get_stride() is not None:
|
||||
# We require a dense last dimension, but the other strides
|
||||
# can be expanded, which results in a smaller tensor
|
||||
for i, s in enumerate(arg.get_stride()[0:-1]):
|
||||
if V.graph.sizevars.statically_known_equals(s, 0):
|
||||
expanded_dims.append(i)
|
||||
# We require a dense last dimension, but the other strides
|
||||
# can be expanded, which results in a smaller tensor
|
||||
maybe_stride = arg.maybe_get_stride()
|
||||
for i in range(len(arg.get_size()) - 1):
|
||||
if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
|
||||
maybe_stride is not None
|
||||
and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
|
||||
):
|
||||
expanded_dims.append(i)
|
||||
|
||||
# Now, pad strides to alignment
|
||||
out_strides = [-1] * len(out_size)
|
||||
|
|
@ -2518,7 +2521,29 @@ def sdpa_constraint(fx_node, *args, **kwargs):
|
|||
stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT
|
||||
|
||||
out_strides[i] = stride
|
||||
return ir.ExternKernel.require_exact_strides(arg, out_strides)
|
||||
|
||||
for dim in expanded_dims:
|
||||
arg = slice_(arg, dim, 0, 1)
|
||||
|
||||
# TODO this is too subtle to get right in lowering, should be handled in match_exact_strides
|
||||
out = ir.ExternKernel.require_exact_strides(arg, out_strides)
|
||||
out = expand(TensorBox(out), out_size)
|
||||
out = ir.try_match_insignificant_strides(out, out_strides)
|
||||
return out
|
||||
|
||||
if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
|
||||
return ir.try_match_insignificant_strides(
|
||||
ir.ExternKernel.realize_input(arg), meta_stride_expr
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(arg, IRNode)
|
||||
and arg.maybe_get_stride() is not None
|
||||
and ir.is_aligned_realized_tensor(arg, ALIGNMENT)
|
||||
):
|
||||
return ir.try_match_insignificant_strides(
|
||||
ir.ExternKernel.realize_input(arg), meta_stride_expr
|
||||
)
|
||||
|
||||
def is_aligned(x):
|
||||
return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue