From 002accfb8d29b720157ae345370aa8591ef6150f Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 6 Feb 2025 17:26:10 -0800 Subject: [PATCH] Check meta strides for expanded dims in effn_attn_bias (#146054) With the `_scaled_dot_product_efficient_attention.default`, we have lowering logic to realize the bias to specific alignment constraints. Some of the dims can be expanded, and we need to keep the stride of that dim at 0 to avoid materializing a larger tensor than we need. Previously, we had checked the stride of the tensor, but if it is not realized, that will not work, so we should check the strides of the meta as well. Note: getting the exact order of realizing/slicing/requiring_exact_strides was a little tricky. I commented to @exclamaforte with an example unable-to-fuse message you get if you do it incorrectly. Fix for https://github.com/pytorch/pytorch/issues/145760 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146054 Approved by: https://github.com/shunting314 --- test/inductor/test_cuda_repro.py | 31 +++++++++++++++ ...onCPU.test_out__refs_bitwise_not_cpu_int64 | 0 ...CUDA.test_out__refs_bitwise_not_cuda_int64 | 0 torch/_inductor/ir.py | 3 +- torch/_inductor/lowering.py | 39 +++++++++++++++---- 5 files changed, 65 insertions(+), 8 deletions(-) delete mode 100644 test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64 delete mode 100644 test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64 diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 2dc21876b7c..b14ef2a8a57 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -151,6 +151,37 @@ class CudaReproTests(TestCase): # dont check rng state self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2]) + def test_effn_attn_bias_padding_misaligned(self): + seqlen_start = 1008 + + for offset in range(-1, 2): + seqlen = seqlen_start + offset + torch._dynamo.reset() + + bsz = 32 + q 
= torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda") + inputs = [q, k, v, mask] + + def f(q, k, v, mask): + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) + + f_compiled = torch.compile(f) + + out, code = run_and_get_code(f_compiled, *inputs) + # padded bias should have an expanded dim + FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) + # single fused padded kernel + FileCheck().check("def call").check_count( + "empty_strided_cuda", 1, exactly=True + ).check("return").run(code[0]) + + self.assertEqual(out, f(*inputs)) + @skipIfRocm def test_input_channels_last(self): m = torch.nn.Sequential( diff --git a/test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64 b/test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64 b/test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 5415f4c572c..6dae9fe7226 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5145,7 +5145,8 @@ class ExternKernel(InputsKernel): allow_padding=False, ): assert order is not None or exact_strides is not None - if x.get_numel() in (0, 1): # Layout doesn't matter + # Layout generally doesn't matter, but some consuming external ops might have requirements + if x.get_numel() in (0, 1) and not exact_strides: return x # require x to have the layout diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 
a30b9aebc30..9000613a6e9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2490,12 +2490,15 @@ def sdpa_constraint(fx_node, *args, **kwargs): out_size = list(arg.get_size()) expanded_dims = [] - if arg.maybe_get_stride() is not None: - # We require a dense last dimension, but the other strides - # can be expanded, which results in a smaller tensor - for i, s in enumerate(arg.get_stride()[0:-1]): - if V.graph.sizevars.statically_known_equals(s, 0): - expanded_dims.append(i) + # We require a dense last dimension, but the other strides + # can be expanded, which results in a smaller tensor + maybe_stride = arg.maybe_get_stride() + for i in range(len(arg.get_size()) - 1): + if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or ( + maybe_stride is not None + and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0) + ): + expanded_dims.append(i) # Now, pad strides to alignment out_strides = [-1] * len(out_size) @@ -2518,7 +2521,29 @@ def sdpa_constraint(fx_node, *args, **kwargs): stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT out_strides[i] = stride - return ir.ExternKernel.require_exact_strides(arg, out_strides) + + for dim in expanded_dims: + arg = slice_(arg, dim, 0, 1) + + # TODO this is too subtle to get right in lowering, should be handled in match_exact_strides + out = ir.ExternKernel.require_exact_strides(arg, out_strides) + out = expand(TensorBox(out), out_size) + out = ir.try_match_insignificant_strides(out, out_strides) + return out + + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) def is_aligned(x): return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 
0