onnxruntime/include/onnxruntime/core
pengwa 8a98874e7e
Flash attention recompute (#20603)
### Flash attn recompute

1. Allow PythonOp(FlashAttn) to be recomputed correctly.
45879ff5c2
2. Use a JSON file to pass the subgraphs selected for recompute.
3c374da678

#### Better Memory Efficiency 

The customer model can run with either PyTorch SDPA or Flash Attention. This PR makes it
possible for the Flash Attention path to work with ORTModule layerwise
recompute. Comparing only the layers, peak memory drops from 45.x GB to
32.x GB. This excludes other parts of the model; a few more optimizations
targeting those parts will follow.
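To make the memory trade-off of layerwise recompute concrete, here is a toy, framework-agnostic sketch of the general idea. It is not ORTModule's implementation, and every name in it (`op`, `run_layer`, `train_step`) is hypothetical. Each toy layer chains several `tanh` sub-ops. Without recompute, every sub-op input is kept for backward; with recompute, only each layer's input is kept as a checkpoint, and the layer's intermediates are regenerated during backward, one layer at a time. The gradients match, while the peak number of stored activations drops.

```python
import math
from typing import List, Tuple

Vec = List[float]


def op(x: Vec) -> Vec:
    # One elementwise sub-operation inside a toy layer: tanh.
    return [math.tanh(v) for v in x]


def op_grad(x: Vec, grad_out: Vec) -> Vec:
    # Backward of tanh from the op's *input*: d/dv tanh(v) = 1 - tanh(v)^2.
    return [g * (1.0 - math.tanh(v) ** 2) for v, g in zip(x, grad_out)]


def run_layer(x: Vec, ops_per_layer: int) -> Tuple[Vec, List[Vec]]:
    # Forward through one toy layer; return the output and every sub-op input.
    inputs = []
    h = x
    for _ in range(ops_per_layer):
        inputs.append(h)
        h = op(h)
    return h, inputs


def train_step(x: Vec, num_layers: int, ops_per_layer: int,
               recompute: bool) -> Tuple[Vec, int]:
    """Forward + backward of sum(output); returns (grad wrt x, peak saved tensors)."""
    saved: List[List[Vec]] = []  # per layer: tensors kept alive for backward
    h = x
    for _ in range(num_layers):
        out, inputs = run_layer(h, ops_per_layer)
        # Without recompute keep all sub-op inputs; with recompute keep only the layer input.
        saved.append([h] if recompute else inputs)
        h = out
    peak = sum(len(s) for s in saved)

    grad = [1.0] * len(x)  # gradient of sum(output) w.r.t. output
    for layer_idx in reversed(range(num_layers)):
        if recompute:
            # Re-run this layer's forward from its checkpoint to regenerate intermediates.
            _, inputs = run_layer(saved[layer_idx][0], ops_per_layer)
            # inputs[0] is the checkpoint itself, so only len(inputs) - 1 tensors are new.
            live = sum(len(s) for s in saved) + (len(inputs) - 1)
            peak = max(peak, live)
        else:
            inputs = saved[layer_idx]
        for inp in reversed(inputs):
            grad = op_grad(inp, grad)
        saved.pop()  # this layer's saved tensors are no longer needed
    return grad, peak


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0]
    g_full, peak_full = train_step(x, num_layers=4, ops_per_layer=8, recompute=False)
    g_rc, peak_rc = train_step(x, num_layers=4, ops_per_layer=8, recompute=True)
    print(peak_full, peak_rc)  # prints 32 11
```

The cost is one extra forward pass per layer during backward, which is the usual compute-for-memory trade that recompute makes.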

#### Better Perf

Using Flash Attention brings an additional 16% reduction in end-to-end time,
with a closely aligned loss curve.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38)

#### Use JSON File to pass Recompute Plans

This works around the maximum length limit on string values defined in
session options.
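A minimal sketch of the general pattern, assuming the recompute plan is simply a JSON list of plan strings: the plans are written to a file, and only the short file path has to travel through a length-limited string option. The helper names and the plan strings below are hypothetical placeholders; the exact plan syntax and the option that carries the path are defined by ORTModule and are not reproduced here.

```python
import json
import os
import tempfile
from typing import List


def write_recompute_plan(plans: List[str], path: str) -> str:
    # Serialize the selected-to-recompute plans as a JSON list of strings.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(plans, f, indent=2)
    return path


def read_recompute_plan(path: str) -> List[str]:
    # Load the plans back and check the assumed shape (a list of strings).
    with open(path, encoding="utf-8") as f:
        plans = json.load(f)
    if not isinstance(plans, list) or not all(isinstance(p, str) for p in plans):
        raise ValueError("expected a JSON list of plan strings")
    return plans


if __name__ == "__main__":
    # Hypothetical plan strings; many selected subgraphs make the inline string long.
    plans = [f"HypotheticalSubgraph{i}+:1:-1" for i in range(200)]
    path = write_recompute_plan(plans, os.path.join(tempfile.mkdtemp(), "recompute_plan.json"))
    print(len(",".join(plans)), "chars inline vs", len(path), "chars for the path")
```

The file also keeps each plan on its own line, which is easier to review and edit than one long delimited string.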

2024-05-21 13:38:19 +08:00
| Name | Last commit | Date |
|---|---|---|
| common | Fix build errors from date/date.h C++20 compatibility (#20139) | 2024-04-02 22:10:25 -07:00 |
| eager | | |
| framework | Expose Reserve() in OrtAllocator to allow custom allocators to work when session.use_device_allocator_for_initializers is specified. (#19904) | 2024-03-28 12:28:37 -07:00 |
| graph | Introduce memory efficient topological sort (#20258) | 2024-04-23 08:00:23 +08:00 |
| optimizer | | |
| platform | Bump linter versions (#18341) | 2023-11-08 13:04:40 -08:00 |
| providers | [java][DML EP] Modifying dml_provider_factory.h so it can compile as a C header file (#20157) | 2024-04-01 21:58:50 -07:00 |
| session | Flash attention recompute (#20603) | 2024-05-21 13:38:19 +08:00 |