pytorch/torch/_higher_order_ops
Simon Fan 72fd7abb35 [ca] fix flex attention backward HOP capture in initial graph (#143155)
FIXES https://github.com/pytorch/pytorch/issues/142313

With previous HOPs, compiled autograd could simply inline into their bodies and capture their post-dispatch ATen representation. That doesn't work for the flex attention HOP, which instead expects whatever proxy tracing mechanism is active to insert it into the graph as a single node. Since compiled autograd itself uses proxy tracing, we can do exactly that.

This is safe because, aside from the reenter_make_fx call, the HOP did not use any other make_fx internals. Compiled autograd also specializes on the AOT backward's saved symints, which should cover any changes in the shapes of the HOP's inputs.

However, there's still an issue: Dynamo doesn't know how to handle `FlexAttentionBackwardHOP` and will graph break on it, so as of this PR the flex attention backward runs in eager. The tlparse output after the compiled autograd capture still looks really scuffed: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10
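For reference, a minimal sketch of the configuration this PR targets: flex_attention's backward driven through compiled autograd. The tensor shapes, the CUDA device, and the use of `torch.compile` as the compiled autograd compiler_fn are illustrative assumptions, not taken from the linked issue.

```python
# Minimal sketch (not from the PR) of the configuration being fixed:
# flex_attention's backward captured by compiled autograd. Shapes, the CUDA
# device, and torch.compile as the compiled autograd compiler_fn are
# illustrative assumptions.
import torch
import torch._dynamo.compiled_autograd as compiled_autograd
from torch.nn.attention.flex_attention import flex_attention


def score_mod(score, b, h, q_idx, kv_idx):
    # Identity score_mod; any score modification exercises the HOP.
    return score


def make_input():
    # (batch, heads, seq_len, head_dim)
    return torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)


q, k, v = make_input(), make_input(), make_input()

# Forward: under torch.compile, flex_attention lowers to the flex attention HOP.
out = torch.compile(flex_attention)(q, k, v, score_mod=score_mod)

# Backward under compiled autograd: with this PR, FlexAttentionBackwardHOP is
# inserted into compiled autograd's proxy-traced graph as a single node.
# Dynamo still graph breaks on it, so the backward HOP itself runs in eager.
with compiled_autograd.enable(torch.compile):
    out.sum().backward()
```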

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143155
Approved by: https://github.com/drisspg
2024-12-13 06:04:39 +00:00
__init__.py [foreach_map] Initial foreach map HOP impl for inference (#142098) 2024-12-11 21:32:11 +00:00
associative_scan.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
auto_functionalize.py Generate slice.Tensor view operations instead of as_strided when split is used in the original program. (#137225) 2024-10-23 17:42:16 +00:00
cond.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
effects.py Remove some unused type ignores (round 1) (#142325) 2024-12-09 18:23:46 +00:00
executorch_call_delegate.py
flex_attention.py [ca] fix flex attention backward HOP capture in initial graph (#143155) 2024-12-13 06:04:39 +00:00
foreach_map.py [foreach_map] Initial foreach map HOP impl for inference (#142098) 2024-12-11 21:32:11 +00:00
hints_wrap.py
invoke_subgraph.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
map.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
out_dtype.py
prim_hop_base.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
run_const_graph.py
scan.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
strict_mode.py
torchbind.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
triton_kernel_wrap.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
utils.py [while_loop] data-dependent op in body_fn (#142031) 2024-12-10 21:54:28 +00:00
while_loop.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00
wrap.py