Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
High level approach:

1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK).
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
   2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:

```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy. My new heuristic achieves 94% accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
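
As a rough illustration of the heuristic and the accuracy figure quoted above, here is a minimal Python sketch. It is not the code from the pull request. The `Timing` record, the `heuristic_accuracy` helper, and the `threshold` and `min_rel_gap` parameters are hypothetical names introduced for this example. Measuring the "more than 10%" gap relative to the faster kernel is also an assumption, since the commit message does not say how that difference was computed.

```python
from dataclasses import dataclass
from typing import Iterable


def use_flash(seq_len: int, num_heads: int, batch_size: int,
              threshold: float = 6.0) -> bool:
    """Heuristic from the commit message: prefer FlashAttention when
    seq_len / (num_heads * batch_size) exceeds the threshold."""
    return seq_len / (num_heads * batch_size) > threshold


@dataclass
class Timing:
    """Hypothetical benchmark record: one shape plus both kernel runtimes."""
    seq_len: int
    num_heads: int
    batch_size: int
    flash_ms: float
    cutlass_ms: float


def heuristic_accuracy(rows: Iterable[Timing], min_rel_gap: float = 0.10) -> float:
    """Fraction of rows where the heuristic picks the faster kernel.

    Only rows whose runtimes differ by more than min_rel_gap are counted.
    The gap is taken relative to the faster runtime (an assumption).
    """
    considered = 0
    correct = 0
    for r in rows:
        faster = min(r.flash_ms, r.cutlass_ms)
        gap = abs(r.flash_ms - r.cutlass_ms) / faster
        if gap <= min_rel_gap:
            continue  # near-tie: either choice is acceptable, so skip it
        considered += 1
        flash_is_faster = r.flash_ms < r.cutlass_ms
        if use_flash(r.seq_len, r.num_heads, r.batch_size) == flash_is_faster:
            correct += 1
    return correct / considered if considered else float("nan")
```

Calling `use_flash(1024, 8, 16)` evaluates `1024 / 128 = 8 > 6` and so selects FlashAttention, while `use_flash(512, 16, 32)` evaluates `512 / 512 = 1` and does not. Skipping near-ties keeps the score focused on shapes where the kernel choice matters, which matches how the commit message frames its accuracy numbers.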

| Name |
|---|
| check_forward_backward_compatibility.py |
| dump_all_function_schemas.py |