pytorch/torch/_higher_order_ops
drisspg 69feef5a94 Fix broken meta function for flex-attention backwards (#146563)
# Summary

Fixes https://github.com/pytorch/pytorch/issues/146377

So what was the original problem? We were codegenning a really weird epilogue:

```Python
        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
        x5 = (xindex % ks3)
        tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
        tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
```

This epilogue was writing to and then reading from overlapping regions of memory, causing a race condition.
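
For intuition, here is a minimal, hypothetical Triton kernel (not the generated code above, and the names are made up) that distills the same write-then-read hazard: each program instance stores a block of one buffer and then loads from a neighbouring instance's block, with no cross-program ordering inside a launch.

```Python
import torch
import triton
import triton.language as tl

@triton.jit
def racy_epilogue(out0_ptr, out1_ptr, n, BLOCK: tl.constexpr):
    # Each program instance stores a block of out0, then loads from a
    # region that a *different* instance stores to. Triton gives no
    # cross-program synchronization within a launch, so the load may
    # observe either the old or the new contents of that memory.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out0_ptr + offs, offs.to(tl.float32), mask=mask)
    shifted = (offs + BLOCK) % n  # wraps into a neighbour's block
    val = tl.load(out0_ptr + shifted, mask=mask)  # racy read
    tl.store(out1_ptr + offs, val, mask=mask)

if torch.cuda.is_available():
    out0 = torch.empty(1024, device="cuda", dtype=torch.float32)
    out1 = torch.empty_like(out0)
    racy_epilogue[(8,)](out0, out1, out0.numel(), BLOCK=128)
```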

### Why were we generating this epilogue

During lowering we created a buffer whose size/stride differed from the return strides promised by the meta function. I think this added an implicit node for permuting the wrongly strided output into the layout the meta function declared. The scheduler then decided it was okay to fuse that permutation into the epilogue; to be honest, I don't know why.
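
The invariant at play, sketched below under assumed names (this is not the actual flex-attention meta function): a meta/fake implementation must report the same sizes *and* strides that the real kernel will produce, e.g. via `torch.empty_strided`, so Inductor never has to insert an implicit layout-conversion node.

```Python
import torch

# Hypothetical sketch: a meta function that mirrors the layouts of its
# inputs. If it instead returned contiguous outputs while the real kernel
# produced differently strided ones, Inductor would insert a layout-
# conversion node -- the kind of node that got fused unsafely here.
def flex_attention_backward_meta_sketch(query, key, value, grad_out):
    def like(t):
        # Match both size and stride of the runtime output.
        return torch.empty_strided(
            t.shape, t.stride(), dtype=t.dtype, device=t.device
        )
    return like(query), like(key), like(value)
```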

This fixes the broken meta function and the original repro. I will add a test; the race is hard to trigger reliably, but a test is better than nothing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
2025-02-08 04:13:52 +00:00
| File | Last commit | Date |
| --- | --- | --- |
| __init__.py | Barebones flat_apply HOP (#146060) | 2025-02-01 16:17:48 +00:00 |
| aoti_call_delegate.py | Introduce aoti_call_delegate HOP (#145630) | 2025-01-31 04:57:36 +00:00 |
| associative_scan.py | Require that all HOPs be imported at import torch time (#145939) | 2025-01-29 22:27:52 +00:00 |
| auto_functionalize.py | [auto_functionalized] Support Tensor(a!)[]? (#145400) | 2025-02-05 14:52:39 +00:00 |
| cond.py | [cond] remove warning for unsupported tuple returns (#145766) | 2025-01-28 03:13:36 +00:00 |
| effects.py | PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) | 2025-01-20 22:37:26 +00:00 |
| executorch_call_delegate.py | | |
| flat_apply.py | Barebones flat_apply HOP (#146060) | 2025-02-01 16:17:48 +00:00 |
| flex_attention.py | Fix broken meta function for flex-attention backwards (#146563) | 2025-02-08 04:13:52 +00:00 |
| foreach_map.py | PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) | 2025-01-20 22:37:26 +00:00 |
| hints_wrap.py | [hop][be] add utils for more comprehensive input alias and mutation (#145298) | 2025-01-23 18:12:28 +00:00 |
| invoke_subgraph.py | PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) | 2025-01-20 22:37:26 +00:00 |
| map.py | [BE]: Apply PERF401 autofixes from ruff (#140980) | 2024-11-20 17:52:07 +00:00 |
| out_dtype.py | [BE] typing for decorators - library (#138969) | 2025-01-15 17:08:55 +00:00 |
| prim_hop_base.py | [BE] typing for decorators - library (#138969) | 2025-01-15 17:08:55 +00:00 |
| run_const_graph.py | [export] Unify single and multiple return for hops (#143227) | 2025-01-13 03:31:14 +00:00 |
| scan.py | [scan] scan dim handling in user-facing scan() (#145179) | 2025-01-30 21:09:07 +00:00 |
| strict_mode.py | | |
| torchbind.py | Remove unused Python variables in torch/[_-a]* (#133492) | 2024-12-12 17:39:14 +00:00 |
| triton_kernel_wrap.py | [inductor] Make triton kernel autotune config defaults backward-compatible (#145494) | 2025-01-29 00:31:39 +00:00 |
| utils.py | [while_loop] specialize when cond_fn return constants (#144515) | 2025-01-30 19:02:34 +00:00 |
| while_loop.py | [hop] fix unbacked_bindings meta for while_loop (#143559) | 2025-01-30 21:33:09 +00:00 |
| wrap.py | Require that all HOPs be imported at import torch time (#145939) | 2025-01-29 22:27:52 +00:00 |