diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 2dc21876b7c..b14ef2a8a57 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -151,6 +151,37 @@ class CudaReproTests(TestCase): # dont check rng state self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2]) + def test_effn_attn_bias_padding_misaligned(self): + seqlen_start = 1008 + + for offset in range(-1, 2): + seqlen = seqlen_start + offset + torch._dynamo.reset() + + bsz = 32 + q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") + mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda") + inputs = [q, k, v, mask] + + def f(q, k, v, mask): + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) + + f_compiled = torch.compile(f) + + out, code = run_and_get_code(f_compiled, *inputs) + # padded bias should have an expanded dim + FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) + # single fused padded kernel + FileCheck().check("def call").check_count( + "empty_strided_cuda", 1, exactly=True + ).check("return").run(code[0]) + + self.assertEqual(out, f(*inputs)) + @skipIfRocm def test_input_channels_last(self): m = torch.nn.Sequential( diff --git a/test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64 b/test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64 b/test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 5415f4c572c..6dae9fe7226 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5145,7 +5145,8 @@ class ExternKernel(InputsKernel): allow_padding=False, ): assert order is not None or exact_strides is not None - if x.get_numel() in (0, 1): # Layout doesn't matter + # Layout generally doesn't matter, but some consuming external ops might have requirements + if x.get_numel() in (0, 1) and not exact_strides: return x # require x to have the layout diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index a30b9aebc30..9000613a6e9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2490,12 +2490,15 @@ def sdpa_constraint(fx_node, *args, **kwargs): out_size = list(arg.get_size()) expanded_dims = [] - if arg.maybe_get_stride() is not None: - # We require a dense last dimension, but the other strides - # can be expanded, which results in a smaller tensor - for i, s in enumerate(arg.get_stride()[0:-1]): - if V.graph.sizevars.statically_known_equals(s, 0): - expanded_dims.append(i) + # We require a dense last dimension, but the other strides + # can be expanded, which results in a smaller tensor + maybe_stride = arg.maybe_get_stride() + for i in range(len(arg.get_size()) - 1): + if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or ( + maybe_stride is not None + and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0) + ): + expanded_dims.append(i) # Now, pad strides to alignment out_strides = [-1] * len(out_size) @@ -2518,7 +2521,29 @@ def sdpa_constraint(fx_node, *args, **kwargs): stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT out_strides[i] = stride - return ir.ExternKernel.require_exact_strides(arg, out_strides) + + for dim in expanded_dims: + arg = slice_(arg, dim, 0, 1) + + # TODO this is too subtle to get right in lowering, should be handled in match_exact_strides + out = ir.ExternKernel.require_exact_strides(arg, out_strides) + out = expand(TensorBox(out), out_size) + out = ir.try_match_insignificant_strides(out, out_strides) + return out + + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) def is_aligned(x): return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0