# Summary

Fixes https://github.com/pytorch/pytorch/issues/146377

So what was the original problem: we were codegenning a really weird epilogue:

```Python
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
x5 = (xindex % ks3)
tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
```

This epilogue wrote to and then read from overlapping regions of memory, causing a race condition.

### Why were we generating this epilogue?

During lowering we created a buffer whose size/strides differed from the expected return strides. I think this added an implicit node for permuting the wrongly strided output into the layout the meta function promised, and the scheduler then fused that permutation into the epilogue; to be honest, I don't know why it considered that fusion legal.

This fixes the broken meta function and the original repro. I will add a test, but the race is hard to trigger reliably; a test is still better than nothing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
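To make the overlap concrete, here is a toy index-math check (the sizes are hypothetical, not taken from the real kernel; `ks3` plays the role of the smaller buffer's extent). The wrapped index `x5 = xindex % ks3` maps many output elements, generally living in different program instances, onto the same `out_ptr0` location, and Triton gives no ordering guarantee between program instances within a single kernel launch, so a load can observe a location another instance has not finished writing.

```Python
import torch

# Toy sizes (hypothetical, not taken from the real kernel).
ks3 = 64                    # extent of the smaller buffer (out_ptr0)
xindex = torch.arange(256)  # flat output indices across all program instances
x5 = xindex % ks3           # wrapped read index into out_ptr0

# Each out_ptr0 location is read on behalf of four different output elements;
# those reads can execute in a different program instance than the store that
# produced the value, so the load may observe stale data.
assert torch.bincount(x5).eq(4).all()
```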
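For illustration, a minimal sketch of the invariant the fix restores (a hypothetical `mylib::toy_permute` op, NOT the actual flex_attention meta function): the fake/meta function must report the same sizes and strides the real kernel produces, otherwise Inductor has to reconcile the two layouts with an implicit permutation node like the one that got fused above.

```Python
import torch

@torch.library.custom_op("mylib::toy_permute", mutates_args=())
def toy_permute(x: torch.Tensor) -> torch.Tensor:
    # The "real" kernel: same shape as x, but laid out contiguously in the
    # transposed order, i.e. non-default strides.
    return x.transpose(-1, -2).contiguous().transpose(-1, -2)

@toy_permute.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Broken version: `return torch.empty_like(x)` would promise default
    # contiguous strides that the real kernel never produces.
    # Fixed version: promise exactly the strides the kernel returns.
    y = torch.empty_like(x.transpose(-1, -2), memory_format=torch.contiguous_format)
    return y.transpose(-1, -2)
```

With the promised and actual strides in agreement, there is no leftover layout-fixup node for the scheduler to (incorrectly) fuse into the epilogue.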