pytorch/torch/_higher_order_ops
Latest commit: drisspg 0f9eea1329 [FlexAttention] Fix multiple calls to flex bug (#140761)
# Summary
Fixes a long-standing bug in the backward pass for flex attention. See https://github.com/pytorch/pytorch/issues/135161 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140761
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-11-16 04:57:04 +00:00
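The pattern the commit summary describes is multiple `flex_attention` calls whose gradients are all taken in a single backward pass. Below is a minimal, hypothetical repro sketch of that pattern; the `score_mod` functions, tensor shapes, and CUDA device are illustrative assumptions, not taken from the PR, and the exact failure mode is described in the linked issue.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Causal masking score_mod (the standard example from the FlexAttention docs).
def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# A simple relative-position bias score_mod.
def rel_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# Two distinct flex_attention calls in one compiled region: the pattern
# the commit message says used to break in the backward pass.
@torch.compile
def two_calls(q, k, v):
    return (flex_attention(q, k, v, score_mod=causal)
            + flex_attention(q, k, v, score_mod=rel_bias))

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q, k, v = (
    torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
    for _ in range(3)
)
two_calls(q, k, v).sum().backward()  # gradients flow through both calls
```

Summing both outputs before `.backward()` forces gradients to flow through both `flex_attention` calls in one autograd pass, which is what exercises the multiple-call backward path.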
| File | Latest commit | Date |
| --- | --- | --- |
| __init__.py | Add base class for single-subgraph inductor HOPs (#139898) | 2024-11-11 16:12:35 +00:00 |
| associative_scan.py | Improvements for associative_scan - slicing of xs (#138858) | 2024-11-05 23:38:21 +00:00 |
| auto_functionalize.py | Generate slice.Tensor view operations instead of as_strided when split is used in the original program (#137225) | 2024-10-23 17:42:16 +00:00 |
| cond.py | [hop free symbols] lift free symbols in example_value when create_graph_input (#138363) | 2024-11-07 04:44:32 +00:00 |
| effects.py | | |
| executorch_call_delegate.py | | |
| flex_attention.py | [FlexAttention] Fix multiple calls to flex bug (#140761) | 2024-11-16 04:57:04 +00:00 |
| hints_wrap.py | | |
| invoke_subgraph.py | [invoke_subgraph] Support symint/int as inputs (#140058) | 2024-11-11 22:26:43 +00:00 |
| map.py | [hop free symbols][refactor] make map's save_for_backward to handle int (#138558) | 2024-11-04 22:48:07 +00:00 |
| out_dtype.py | | |
| prim_hop_base.py | Add base class for single-subgraph inductor HOPs (#139898) | 2024-11-11 16:12:35 +00:00 |
| run_const_graph.py | | |
| scan.py | [hop free symbols] lift free symbols in example_value when create_graph_input (#138363) | 2024-11-07 04:44:32 +00:00 |
| strict_mode.py | | |
| torchbind.py | | |
| triton_kernel_wrap.py | [RFC] Implement caching for user defined triton kernels (#140326) | 2024-11-16 02:37:16 +00:00 |
| utils.py | [hop free symbols] lift free symbols in example_value when create_graph_input (#138363) | 2024-11-07 04:44:32 +00:00 |
| while_loop.py | [hop free symbols] lift free symbols in example_value when create_graph_input (#138363) | 2024-11-07 04:44:32 +00:00 |
| wrap.py | | |