Check meta strides for expanded dims in effn_attn_bias (#146054)

With the `_scaled_dot_product_efficient_attention.default`, we have lowering logic to realize the bias to specific alignment constraints. Some of the dims can be expanded, and we need to keep the stride of each such dim at 0 to avoid materializing a larger tensor than we need. Previously, we checked the strides of the tensor, but if it is not realized, that will not work, so we should check the strides of the meta as well.

Note: getting the exact order of realizing/slicing/requiring_exact_strides right was a little tricky. I left a comment for @exclamaforte with an example of the unable-to-fuse message you get if you do it incorrectly.

Fix for https://github.com/pytorch/pytorch/issues/145760

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146054
Approved by: https://github.com/shunting314
This commit is contained in:
eellison 2025-02-06 17:26:10 -08:00 committed by PyTorch MergeBot
parent 71e8a2bda4
commit 002accfb8d
5 changed files with 65 additions and 8 deletions

View file

@@ -151,6 +151,37 @@ class CudaReproTests(TestCase):
# dont check rng state
self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2])
def test_effn_attn_bias_padding_misaligned(self):
# Sweep seqlen through 1007/1008/1009 so the attention-bias padding path is
# exercised at, just below, and just above an aligned sequence length.
seqlen_start = 1008
for offset in range(-1, 2):
seqlen = seqlen_start + offset
# Fresh compile per seqlen — otherwise a cached graph would be reused.
torch._dynamo.reset()
bsz = 32
q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
# Mask has a size-1 dim (dim 1) that gets broadcast/expanded by SDPA.
mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda")
inputs = [q, k, v, mask]
def f(q, k, v, mask):
return F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
f_compiled = torch.compile(f)
out, code = run_and_get_code(f_compiled, *inputs)
# padded bias should have an expanded dim
# (a stride of 0 in the generated buffer allocation)
FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
# single fused padded kernel
# exactly one intermediate allocation means the padding fused into one kernel
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("return").run(code[0])
# compiled output must match eager
self.assertEqual(out, f(*inputs))
@skipIfRocm
def test_input_channels_last(self):
m = torch.nn.Sequential(

View file

@@ -5145,7 +5145,8 @@ class ExternKernel(InputsKernel):
allow_padding=False,
):
assert order is not None or exact_strides is not None
if x.get_numel() in (0, 1): # Layout doesn't matter
# Layout generally doesn't matter, but some consuming external ops might have requirements
if x.get_numel() in (0, 1) and not exact_strides:
return x
# require x to have the layout

View file

@@ -2490,12 +2490,15 @@ def sdpa_constraint(fx_node, *args, **kwargs):
out_size = list(arg.get_size())
expanded_dims = []
if arg.maybe_get_stride() is not None:
# We require a dense last dimension, but the other strides
# can be expanded, which results in a smaller tensor
for i, s in enumerate(arg.get_stride()[0:-1]):
if V.graph.sizevars.statically_known_equals(s, 0):
expanded_dims.append(i)
# We require a dense last dimension, but the other strides
# can be expanded, which results in a smaller tensor
maybe_stride = arg.maybe_get_stride()
for i in range(len(arg.get_size()) - 1):
if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
maybe_stride is not None
and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
):
expanded_dims.append(i)
# Now, pad strides to alignment
out_strides = [-1] * len(out_size)
@@ -2518,7 +2521,29 @@ def sdpa_constraint(fx_node, *args, **kwargs):
stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT
out_strides[i] = stride
return ir.ExternKernel.require_exact_strides(arg, out_strides)
for dim in expanded_dims:
arg = slice_(arg, dim, 0, 1)
# TODO this is too subtle to get right in lowering, should be handled in match_exact_strides
out = ir.ExternKernel.require_exact_strides(arg, out_strides)
out = expand(TensorBox(out), out_size)
out = ir.try_match_insignificant_strides(out, out_strides)
return out
if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
return ir.try_match_insignificant_strides(
ir.ExternKernel.realize_input(arg), meta_stride_expr
)
if (
isinstance(arg, IRNode)
and arg.maybe_get_stride() is not None
and ir.is_aligned_realized_tensor(arg, ALIGNMENT)
):
return ir.try_match_insignificant_strides(
ir.ExternKernel.realize_input(arg), meta_stride_expr
)
# True when the size hint for x's last dimension is a multiple of ALIGNMENT.
# NOTE(review): uses a size *hint*, so for dynamic shapes this reflects the
# example value, not a guaranteed-at-runtime property — confirm callers accept that.
def is_aligned(x):
return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0