pytorch/torchgen
Alnis Murtovi 48929184e9 AutoHeuristic: mixed_mm heuristic for A100 (#131613)
This PR introduces changes to AutoHeuristic that allow one to learn a heuristic as a decision tree. I used this to learn a heuristic for mixed_mm on A100 that consistently performs better than the default choice (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L402).
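Roughly, learning such a heuristic boils down to fitting a small decision tree on the collected autotuning data. A minimal sketch, assuming the benchmark results have already been dumped to a CSV (the feature names and file name are hypothetical; this is not the actual AutoHeuristic training code):

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical dump of collected autotuning results: one row per input,
# with input properties as features and the fastest choice as the label.
data = pd.read_csv("mixed_mm_a100.csv")
features = data[["m", "k", "n", "mat1_stride_0", "mat2_stride_1"]]
labels = data["best_choice"]  # fastest config per input, incl. the fallback

# Shallow depth and a minimum leaf fraction (cf. the hyperparameters in the
# results table: crit=entropy, max_depth=5, min_samples_leaf=0.01) keep the
# tree simple enough to be emitted as readable branching code.
clf = DecisionTreeClassifier(
    criterion="entropy", max_depth=5, min_samples_leaf=0.01
)
clf.fit(features, labels)
```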

This is what the results look like:
Explanation of columns:
**wrong_max_spdup**: For inputs where the heuristic is wrong, how much better the best choice would have been in the worst case
**wrong_gman_spdup**: For inputs where the heuristic is wrong, how much better the best choice is on average (geomean)
**max_spdup_default**: Highest speedup achieved by the learned heuristic over the default choice
**gman_spdup_default**: Geomean speedup achieved by the learned heuristic over the default choice
**max_slowdown_default**: For inputs where the default choice is better than the heuristic's prediction, how much better the default is in the worst case
**non_default_preds**: Number of times the learned heuristic predicted a choice other than the default choice
**default_better**: Number of times the default choice is better than the choice made by the heuristic
```
  set     crit  max_depth  min_samples_leaf  correct  wrong  unsure  total  wrong_max_spdup  wrong_gman_spdup    max_spdup_default  gman_spdup_default  max_slowdown_default  non_default_preds  default_better
train  entropy          5              0.01     2376    740     323   3439         1.855386          1.063236            11.352318            3.438279              1.022164               3116               2
 test  entropy          5              0.01      563    183      71    817         1.622222          1.060897            10.084181            3.507741              1.017039                746               2
```

While the number of wrong predictions is high, the best choice is only around 6% better on average (geomean) in those cases. What is important is that the choice predicted by the learned heuristic still performs better than the default choice.
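For reference, the speedup columns above can be computed along these lines (a hypothetical helper, not code from this PR; all arguments are per-input runtimes, lower is better):

```python
import numpy as np

def geomean(x: np.ndarray) -> float:
    return float(np.exp(np.log(x).mean()))

def summarize(time_best, time_pred, time_default):
    """time_best/time_pred/time_default: runtime of the best choice, the
    heuristic's predicted choice, and the default choice, per input."""
    wrong = time_pred > time_best        # heuristic missed the best choice
    worse = time_pred > time_default     # default choice was better
    return {
        "wrong_max_spdup": (time_pred[wrong] / time_best[wrong]).max(),
        "wrong_gman_spdup": geomean(time_pred[wrong] / time_best[wrong]),
        "max_spdup_default": (time_default / time_pred).max(),
        "gman_spdup_default": geomean(time_default / time_pred),
        "max_slowdown_default": (time_pred[worse] / time_default[worse]).max(),
        "default_better": int(worse.sum()),
    }
```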

I evaluated my heuristic on gpt-fast `meta-llama/Llama-2-7b-chat-hf` with int8 weight quantization. To get `tuned_mixed_mm` to trigger, I had to replace `F.linear()` in https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py#L355 with `torch.matmul(input, self.weight.t().to(dtype=input.dtype))`, because the mixed_mm pattern does not match if there is a transpose between a cast and the matmul.
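Concretely, the change looks roughly like this (a paraphrase of the relevant lines; the rest of the gpt-fast module, e.g. the quantization scales, is omitted):

```python
import torch
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    # ... int8 `weight` buffer is registered elsewhere in quantize.py ...

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Before: F.linear transposes the weight internally, so the
        # transpose sits between the dtype cast and the matmul and the
        # mixed_mm pattern does not match:
        #   return F.linear(input, self.weight.to(dtype=input.dtype))
        # After: transpose before the cast, so the cast feeds the matmul
        # directly and tuned_mixed_mm can kick in:
        return torch.matmul(input, self.weight.t().to(dtype=input.dtype))
```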
| batch size | prompt length |     fallback |    heuristic | speedup |
|-----------:|--------------:|-------------:|-------------:|--------:|
|          1 |             7 |  75.31 tok/s | 148.83 tok/s |    1.97 |
|          1 |            11 |  75.99 tok/s | 148.15 tok/s |    1.94 |
|          4 |             7 | 103.48 tok/s | 472.00 tok/s |    4.56 |
|          4 |            11 | 103.56 tok/s | 371.36 tok/s |    3.58 |
|          8 |             7 | 201.92 tok/s | 813.44 tok/s |    4.02 |
|          8 |            11 | 201.76 tok/s | 699.36 tok/s |    3.46 |

Currently, the heuristic only applies to the following inputs (see the guard sketch after this list):
- m <= 128, k >= 1024, n >= 1024 (for these sizes, one of the triton kernels wins in most cases, but the heuristic still has to be careful not to choose a config that performs worse than the fallback)
- k % 256 == 0 (if k is not a multiple of the block size, some choices perform extremely badly; in one case, a config that usually performs very well was 130x slower)
- mat1 not transposed
- mat2 transposed (when mat2 is not transposed, it was hard for the learned heuristic to reliably detect the cases where the fallback is better)
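Taken together, the gating condition amounts to a check like the following (a hypothetical standalone sketch, not the actual inductor code):

```python
import torch

def mixed_mm_heuristic_applies(mat1: torch.Tensor, mat2: torch.Tensor) -> bool:
    """True if the learned A100 mixed_mm heuristic should be consulted."""
    m, k = mat1.shape
    _, n = mat2.shape
    return (
        m <= 128 and k >= 1024 and n >= 1024
        # k must be a multiple of the block size; otherwise some configs
        # become pathologically slow (up to ~130x in one observed case).
        and k % 256 == 0
        # mat1 not transposed (row-major), mat2 transposed (column-major).
        and mat1.stride(1) == 1
        and mat2.stride(0) == 1
    )
```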

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131613
Approved by: https://github.com/eellison
2024-08-02 13:54:37 +00:00