pytorch/torch/_inductor/kernel
drisspg 69feef5a94 Fix broken meta function for flex-attention backwards (#146563)
# Summary

Fixes https://github.com/pytorch/pytorch/issues/146377

So what was the original problem: we were codegening a really weird epilogue:

```Python
        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
        x5 = (xindex % ks3)
        tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
        tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
```

This epilogue was writing to and then reading from overlapping regions of memory, causing a race condition.
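As a minimal illustration of the hazard (a simplified analogue on CPU tensors, not the actual Triton kernel): writing through one index expression and then reading back through another that maps to overlapping addresses is only safe if every write completes before any read. Sequential execution guarantees that ordering; concurrent program instances in a GPU kernel do not.

```python
import torch

buf = torch.zeros(8)
write_idx = torch.tensor([0, 1, 2, 3])
read_idx = torch.tensor([2, 3, 4, 5])  # overlaps the written range

# Sequential execution: the read observes the completed writes.
buf[write_idx] = torch.arange(4, dtype=buf.dtype)
out = buf[read_idx]  # tensor([2., 3., 0., 0.])

# In a parallel kernel, the store and load above may interleave across
# thread blocks, so the overlapping elements (2 and 3 here) would be
# nondeterministic -- exactly the race the generated epilogue exhibited.
```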

### Why were we generating this epilogue?

During lowering we created a buffer with a size/stride different from the expected return strides. I think this added an implicit node for permuting the wrongly strided output into the layout expected by the meta function. The scheduler then decided it was okay to fuse that permutation into the epilogue; tbh I don't know why.

This PR fixes the broken meta function and the original repro. I will add a test, but the race is hard to trigger reliably; a test is still better than nothing.
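The underlying principle of the fix: a meta ("fake tensor") function should return outputs whose strides match what the real kernel produces, or the compiler has to insert a layout-fixing copy that can then be fused badly. A hedged sketch of that idea (hypothetical function name and signature, not the actual flex-attention meta registration):

```python
import torch

def attention_backward_meta(query, key, value, grad_out):
    # Hypothetical meta function: allocate gradients that mirror the
    # size *and* stride of the corresponding inputs, so no implicit
    # permutation node is needed downstream.
    def like(t):
        return torch.empty_strided(
            t.shape, t.stride(), dtype=t.dtype, device=t.device
        )
    return like(query), like(key), like(value)

# Non-contiguous input layout, e.g. a transposed [B, H, S, D] tensor.
q = torch.randn(2, 4, 8, 16).transpose(1, 2)
gq, gk, gv = attention_backward_meta(q, q, q, q)
assert gq.stride() == q.stride()  # layout matches, no fix-up copy needed
```

If the meta function instead returned contiguous tensors here, the declared strides would disagree with the kernel's actual output layout, which is the kind of mismatch this PR corrects.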

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
2025-02-08 04:13:52 +00:00
__init__.py
bmm.py [Rocm][Inductor][CK] silence ck package not installed warning when CK backend is not used to autotune bmm (#145626) 2025-01-25 08:44:35 +00:00
conv.py PEP585 update - torch/_inductor (#145198) 2025-01-21 21:04:33 +00:00
flex_attention.py Fix broken meta function for flex-attention backwards (#146563) 2025-02-08 04:13:52 +00:00
flex_decoding.py Enable non power of 2 head_dim for FlexAttention (#133495) 2025-01-23 17:05:38 +00:00
mm.py [AOTI] Fix an unaligned memory access issue in mm_template (#146293) 2025-02-04 17:12:04 +00:00
mm_common.py PEP585 update - torch/_inductor (#145198) 2025-01-21 21:04:33 +00:00
mm_plus_mm.py
mm_scaled.py PEP585 update - torch/_inductor (#145198) 2025-01-21 21:04:33 +00:00
unpack_mixed_mm.py PEP585 update - torch/_inductor (#145198) 2025-01-21 21:04:33 +00:00