Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
High-level approach:

1. I generated a bunch of data comparing the FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK).
2. I trained a decision tree using a standard train/val split and hyperparameter sweeps (https://pastebin.com/fjYX1HjR). As part of this, I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:

```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, restricted to cases where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy; my new heuristic achieves 94% accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
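As a minimal sketch, the learned heuristic above can be wrapped in a dispatch predicate. The function name `prefer_flash_attention` is hypothetical (not from the PR); the variables `seq_len`, `num_heads`, and `batch_size` are the ones named in the heuristic itself:

```python
def prefer_flash_attention(seq_len: int, num_heads: int, batch_size: int) -> bool:
    """Hypothetical dispatch predicate based on the heuristic in this commit.

    FlashAttention tends to win when the sequence length is large relative
    to the product of the number of heads and the batch size; otherwise the
    Cutlass implementation is preferred.
    """
    # use_flash = seq_len / (num_heads * batch_size) > 6
    return seq_len / (num_heads * batch_size) > 6

# Example: long sequences with a modest head/batch product favor FlashAttention.
print(prefer_flash_attention(seq_len=2048, num_heads=16, batch_size=8))   # 2048/128 = 16 > 6 -> True
# Short sequences with many heads and large batches favor Cutlass.
print(prefer_flash_attention(seq_len=128, num_heads=16, batch_size=32))   # 128/512 = 0.25 -> False
```

Note the single threshold (6) is what the decision tree reduced to after feature augmentation; the ratio itself is the interaction feature that the tree found most predictive.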