FIXES https://github.com/pytorch/pytorch/issues/142313

With previous HOPs, compiled autograd could simply inline into their bodies and obtain their post-dispatch aten representation. That doesn't work with the flex attention HOP, which only wants some proxy tracing mechanism to insert it into its graph. Since compiled autograd already uses proxy tracing, we can do exactly that. This is safe because, apart from the reenter_make_fx call, the HOP does not use any other make_fx internals, and compiled autograd specializes on the AOT backward's saved symints, which should cover any changes in the shapes of the HOP's inputs.

However, one issue remains: Dynamo doesn't know how to handle `FlexAttentionBackwardHOP` and will graph break, so as of this PR the flex attention backward runs in eager. The tlparse looks really scuffed after the compiled autograd capture: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143155
Approved by: https://github.com/drisspg
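To make the scenario concrete, below is a minimal sketch (not taken from the PR) of flex attention running under torch.compile with compiled autograd capturing the backward. The shapes, device, `causal` score_mod, and `compiler_fn` are illustrative assumptions, and `torch._dynamo.compiled_autograd.enable` is a private entry point used here only for demonstration.

```python
# Minimal sketch of the scenario this change targets: flex_attention compiled
# with torch.compile while compiled autograd captures the backward graph.
# Shapes, device, and the causal score_mod are illustrative assumptions.
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch._dynamo import compiled_autograd

def causal(score, b, h, q_idx, kv_idx):
    # Standard causal-masking score_mod: keep scores where q_idx >= kv_idx.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

def compiler_fn(gm):
    # Compiler applied by compiled autograd to the captured backward graph.
    return torch.compile(gm, backend="inductor")

q, k, v = (
    torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True) for _ in range(3)
)

compiled_attn = torch.compile(flex_attention)

# Enable compiled autograd for this region (private API, as used in PyTorch tests).
with compiled_autograd.enable(compiler_fn):
    out = compiled_attn(q, k, v, score_mod=causal)
    # Compiled autograd proxy-traces FlexAttentionBackwardHOP into its graph
    # rather than inlining a post-dispatch aten body; since Dynamo currently
    # graph-breaks on the HOP, the flex attention backward runs in eager.
    out.sum().backward()
```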
| Name |
|---|
| __init__.py |
| associative_scan.py |
| auto_functionalize.py |
| cond.py |
| effects.py |
| executorch_call_delegate.py |
| flex_attention.py |
| foreach_map.py |
| hints_wrap.py |
| invoke_subgraph.py |
| map.py |
| out_dtype.py |
| prim_hop_base.py |
| run_const_graph.py |
| scan.py |
| strict_mode.py |
| torchbind.py |
| triton_kernel_wrap.py |
| utils.py |
| while_loop.py |
| wrap.py |