pengwa 8a98874e7e
Flash attention recompute (#20603)
### Flash attn recompute

1. Allow PythonOp(FlashAttn) to be recomputed correctly. (45879ff5c2)
2. Use JSON to pass the selected-to-recompute subgraphs. (3c374da678)

#### Better Memory Efficiency 

The customer model can run with either PyTorch SDPA or Flash Attention. This PR
makes the Flash Attention path work with ORTModule layer-wise recompute. Peak
memory drops from 45.x GB to 32.x GB when comparing only the layers (not
including other pieces; a few more optimizations targeting those pieces will
follow later).

#### Better Perf

Using Flash Attention brings an additional 16% end-to-end time reduction, with
a closely aligned loss curve.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38)

#### Use JSON File to pass Recompute Plans

Passing the recompute plan through a JSON file overcomes the maximum-length
limitation on strings defined in session options.
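
As a minimal sketch of this flow: the plan is serialized to a JSON file on disk
and the file path (rather than the full plan string) is handed to the runtime.
The plan entries, the file name, and the `ORTMODULE_MEMORY_OPT_CONFIG`
environment variable below are illustrative assumptions, not details confirmed
by this PR description.

```python
import json
import os
import tempfile

# Hypothetical recompute-plan entries; the real schema is defined by
# ORTModule's memory optimizer and may differ.
plan = ["FlashAttn+:1:-1"]

# Write the plan to a JSON file instead of packing it into a session-option
# string, which is subject to a maximum-length limit.
plan_path = os.path.join(tempfile.gettempdir(), "recompute_plan.json")
with open(plan_path, "w") as f:
    json.dump(plan, f)

# ORTModule is assumed to pick the plan up from this path; the variable name
# is an assumption for illustration.
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = plan_path
```

The indirection through a file keeps session options small regardless of how
many subgraphs are selected for recompute.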

### Motivation and Context
Enable ORTModule layer-wise recompute for models using Flash Attention,
reducing peak memory and improving end-to-end training time.
2024-05-21 13:38:19 +08:00