pytorch/torch/distributed
Horace He 547bef11ee tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.
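
The commit message does not say which augmented features were used in step 2a. As a rough sketch only, one common way to let a decision tree see feature interactions is to add pairwise products and ratios of the raw inputs. The function name `augment_features` and the choice of products and ratios are assumptions for illustration, not taken from the linked scripts.

```python
from itertools import combinations


def augment_features(row):
    """Return a copy of `row` extended with pairwise products and ratios.

    Hypothetical helper: the actual augmentation used for the PR is not
    described in the commit message. Assumes all values are positive.
    """
    out = dict(row)
    for a, b in combinations(sorted(row), 2):
        out[f"{a}*{b}"] = row[a] * row[b]
        out[f"{a}/{b}"] = row[a] / row[b]
        out[f"{b}/{a}"] = row[b] / row[a]
    return out


# Illustrative shape only, not a row from the benchmark data.
sample = {"batch_size": 4, "num_heads": 8, "seq_len": 512}
augmented = augment_features(sample)
```

With ratio features like `seq_len/num_heads` available, a shallow tree can split on a single interaction term, which is the kind of rule the final heuristic expresses.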

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```
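
For readers who want to try the rule on their own shapes, here is a minimal standalone sketch of the expression above as a Python function. The function signature is an assumption for illustration; it is not the code that the PR changed inside PyTorch.

```python
def use_flash(seq_len: int, num_heads: int, batch_size: int) -> bool:
    """Apply the heuristic from the commit message.

    Returns True when the rule would pick FlashAttention and False when it
    would pick the Cutlass implementation. Illustrative wrapper only.
    """
    return seq_len / (num_heads * batch_size) > 6


# Long sequences with few heads and a small batch favor FlashAttention.
long_seq = use_flash(seq_len=4096, num_heads=16, batch_size=8)   # ratio 32
short_seq = use_flash(seq_len=128, num_heads=16, batch_size=8)   # ratio 1
boundary = use_flash(seq_len=768, num_heads=16, batch_size=8)    # ratio exactly 6
```

The comparison is strict, so a ratio of exactly 6 does not select FlashAttention.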

TL;DR: On the subset of my dataset where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy and my new heuristic achieves 94%.
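
The accuracy figure above can be read as "how often the heuristic picks the faster backend, counting only shapes where the two backends differ by more than 10%". The sketch below shows one way to compute such a number. The relative-gap definition, the function name `filtered_accuracy`, and the timing rows are all assumptions; the rows are placeholder values, not measurements from the benchmark.

```python
def filtered_accuracy(rows, predict_flash, min_gap=0.10):
    """Fraction of rows where the prediction matches the faster backend.

    Rows whose runtimes differ by `min_gap` or less (relative to the faster
    one) are skipped. This gap definition is an assumption.
    """
    hits = total = 0
    for shape, flash_ms, cutlass_ms in rows:
        gap = abs(flash_ms - cutlass_ms) / min(flash_ms, cutlass_ms)
        if gap <= min_gap:
            continue  # backends are too close for the choice to matter
        total += 1
        flash_is_faster = flash_ms < cutlass_ms
        hits += predict_flash(*shape) == flash_is_faster
    return hits / total if total else float("nan")


# Placeholder rows: ((seq_len, num_heads, batch_size), flash_ms, cutlass_ms).
rows = [
    ((4096, 16, 8), 1.0, 1.5),    # flash faster, heuristic predicts flash
    ((128, 16, 8), 2.0, 1.0),     # cutlass faster, heuristic predicts cutlass
    ((1024, 8, 4), 1.00, 1.05),   # gap of 5%, excluded from the count
    ((512, 16, 8), 1.0, 2.0),     # flash faster, heuristic predicts cutlass
]

accuracy = filtered_accuracy(rows, lambda s, h, b: s / (h * b) > 6)
```

Excluding near-ties keeps the metric focused on shapes where a wrong choice has a noticeable cost.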

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-04-21 23:28:44 +00:00
_composable [composable] Enable replicate + trec_shard overall (#98890) 2023-04-15 01:09:00 +00:00
_shard
_sharded_tensor
_sharding_spec
_spmd tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644) 2023-04-21 23:28:44 +00:00
_tensor [DTensor] Change Sharding algorithm to be in line with `torch.chunk()` (#98722) 2023-04-21 02:05:22 +00:00
_tools
algorithms Enable fused optimizer for DP (#98270) 2023-04-13 20:16:32 +00:00
autograd
benchmarks
checkpoint Convert logging f-strings to use % format, part three (#98704) 2023-04-11 13:17:56 +00:00
elastic Convert logging f-strings to use % format, part five (#98765) 2023-04-11 13:17:59 +00:00
examples
fsdp [FSDP][BE] Remove unused code (#99731) 2023-04-21 23:11:37 +00:00
launcher Convert logging f-strings to use % format, part four (#98705) 2023-04-11 13:17:59 +00:00
nn Convert logging f-strings to use % format (#98697) 2023-04-10 12:19:31 +00:00
optim Convert logging f-strings to use % format, part four (#98705) 2023-04-11 13:17:59 +00:00
pipeline
rpc Convert logging f-strings to use % format, part five (#98765) 2023-04-11 13:17:59 +00:00
tensor
__init__.py Revert "[c10d] Faster coalescing (#98793)" 2023-04-21 09:15:04 +00:00
_composable_state.py
_functional_collectives.py Reland python ops (#99170) 2023-04-18 15:15:46 +00:00
argparse_util.py
c10d_error_logger.py
constants.py
CONTRIBUTING.md
distributed_c10d.py Revert "[c10d] Faster coalescing (#98793)" 2023-04-21 09:15:04 +00:00
launch.py
logging_handlers.py
remote_device.py
rendezvous.py
run.py Convert logging f-strings to use % format, part four (#98705) 2023-04-11 13:17:59 +00:00
utils.py DDP forward support custom stream accelerated copy. (#98723) 2023-04-14 20:19:56 +00:00