pytorch/test/forward_backward_compatibility
Horace He 547bef11ee tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```
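As a minimal sketch, the one-line rule above can be written as a standalone Python predicate. The function name `use_flash_heuristic` and the `threshold` parameter are hypothetical, introduced here only for illustration; this is not the dispatch code PyTorch actually uses.

```python
def use_flash_heuristic(seq_len: int, num_heads: int, batch_size: int,
                        threshold: float = 6.0) -> bool:
    """Return True when the rule from the commit message would pick FlashAttention.

    Hypothetical helper: it mirrors `seq_len / (num_heads * batch_size) > 6`
    and nothing more. The real selection logic may check other conditions.
    """
    if num_heads <= 0 or batch_size <= 0:
        raise ValueError("num_heads and batch_size must be positive")
    return seq_len / (num_heads * batch_size) > threshold


# Long sequences with few heads and a small batch favor FlashAttention under this rule.
long_seq_choice = use_flash_heuristic(seq_len=4096, num_heads=8, batch_size=4)
# Short sequences with many heads and a larger batch fall back to the other kernel.
short_seq_choice = use_flash_heuristic(seq_len=128, num_heads=16, batch_size=8)
```

The comparison is strict, so a ratio of exactly 6 does not select FlashAttention.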

TL;DR: On the subset of my dataset where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy. My new heuristic achieves 94% accuracy.
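One plausible way to compute an accuracy figure of this kind is sketched below. It assumes each benchmark row records the input shape plus one timing per kernel, and it reads "differ by more than 10%" as the slower time exceeding the faster one by more than 10%; both are assumptions, since the commit message does not spell out the exact definition. The rows are made-up illustrative values, not the data from the linked pastebin, and `heuristic_accuracy` and `new_rule` are hypothetical names.

```python
import math


def heuristic_accuracy(rows, predict_flash, min_gap=0.10):
    """Fraction of rows where the heuristic picks the faster kernel.

    Rows where the two timings are within `min_gap` of each other are
    skipped, because either choice is roughly equally good there.
    """
    kept = 0
    correct = 0
    for seq_len, num_heads, batch_size, flash_ms, cutlass_ms in rows:
        slow = max(flash_ms, cutlass_ms)
        fast = min(flash_ms, cutlass_ms)
        if slow / fast <= 1.0 + min_gap:
            continue  # kernels are too close to call; ignore this row
        kept += 1
        flash_is_faster = flash_ms < cutlass_ms
        if predict_flash(seq_len, num_heads, batch_size) == flash_is_faster:
            correct += 1
    return correct / kept if kept else math.nan


def new_rule(seq_len, num_heads, batch_size):
    # The rule stated in the commit message.
    return seq_len / (num_heads * batch_size) > 6


# Synthetic timings (milliseconds), for illustration only.
rows = [
    (4096, 8, 1, 1.0, 2.0),    # flash faster, rule picks flash: correct
    (128, 16, 8, 2.0, 1.0),    # cutlass faster, rule picks cutlass: correct
    (2048, 8, 2, 1.00, 1.05),  # within 10%: skipped
    (64, 2, 2, 1.5, 1.0),      # cutlass faster, rule picks flash: wrong
]
accuracy = heuristic_accuracy(rows, new_rule)
```

Passing the predictor as an argument lets the same function score the existing heuristic and the new one on identical rows.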

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-04-21 23:28:44 +00:00
..
check_forward_backward_compatibility.py tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644) 2023-04-21 23:28:44 +00:00
dump_all_function_schemas.py