[build-system]
requires = [
    "setuptools",
    "wheel",
    "astunparse",
    "numpy",
    "ninja",
    "pyyaml",
    "cmake",
    "typing-extensions",
    "requests",
]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

[tool.black]
line-length = 88
target-version = ["py38"]

[tool.isort]
src_paths = ["caffe2", "torch", "torchgen", "functorch", "test"]
extra_standard_library = ["typing_extensions"]
skip_gitignore = true
skip_glob = ["third_party/*"]
atomic = true
profile = "black"
indent = 4
line_length = 88
lines_after_imports = 2
multi_line_output = 3
include_trailing_comma = true
combine_as_imports = true

[tool.usort.known]
first_party = ["caffe2", "torch", "torchgen", "functorch", "test"]
standard_library = ["typing_extensions"]

[tool.ruff]
target-version = "py38"
line-length = 120
src = ["caffe2", "torch", "torchgen", "functorch", "test"]

[tool.ruff.lint]
# NOTE: Synchronize the ignores with .flake8
ignore = [
    # these ignores are from flake8-bugbear; please fix!
    "B007", "B008", "B017",
    "B018", # Useless expression
    "B023",
    "B028", # No explicit `stacklevel` keyword argument found
    "E402",
    "C408", # C408 ignored because we like the dict keyword argument syntax
    "E501", # E501 is not flexible enough, we're using B950 instead
    "E721",
    "E731", # Assign lambda expression
    "E741",
    "EXE001",
    "F405",
    "F841",
    # these ignores are from flake8-logging-format; please fix!
    "G101",
    # these ignores are from ruff NPY; please fix!
    "NPY002",
    # these ignores are from ruff PERF; please fix!
    "PERF203",
    "PERF401",
    "PERF403",
    # these ignores are from PYI; please fix!
    "PYI024",
    "PYI036",
    "PYI041",
    "PYI056",
    "SIM102", "SIM103", "SIM112", # flake8-simplify code styles
    "SIM113", # please fix
    "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
    "SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
    "SIM110",
    "SIM114", # Combine `if` branches using logical `or` operator
    "SIM115",
    "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
    "SIM117",
    "SIM118",
    "UP006", # keep-runtime-typing
    "UP007", # keep-runtime-typing
]
select = [
    "B",
    "B904", # Re-raised error without specifying the cause via the from keyword
    "C4",
    "G",
    "E",
    "EXE",
    "F",
    "SIM1",
    "SIM911",
    "W",
    # Not included in flake8
    "FURB",
    "LOG",
    "NPY",
    "PERF",
    "PGH004",
    "PIE794",
    "PIE800",
    "PIE804",
    "PIE807",
    "PIE810",
    "PLC0131", # type bivariance
    "PLC0132", # type param mismatch
    "PLC0205", # string as __slots__
    "PLC3002", # unnecessary-direct-lambda-call
    "PLE",
    "PLR0133", # constant comparison
    "PLR0206", # property with params
    "PLR1722", # use sys exit
    "PLR1736", # unnecessary list index
    "PLW0129", # assert on string literal
    "PLW0133", # useless exception statement
    "PLW0406", # import self
    "PLW0711", # binary op exception
    "PLW1509", # preexec_fn not safe with threads
    "PLW2101", # useless lock statement
    "PLW3301", # nested min max
    "PT006", # TODO: enable more PT rules
    "PT022",
    "PT023",
    "PT024",
    "PT025",
    "PT026",
    "PYI",
    "Q003", # avoidable escaped quote
    "Q004", # unnecessary escaped quote
    "RSE",
    "RUF008", # mutable dataclass default
    "RUF015", # access first ele in constant time
    "RUF016", # type error non-integer index
    "RUF017",
    "RUF018", # no assignment in assert
    "RUF019", # unnecessary-key-check
    "RUF024", # from keys mutable
    "RUF026", # default factory kwarg
    "TCH",
    "TRY002", # ban vanilla raise (todo fix NOQAs)
    "TRY302",
    "TRY401", # verbose-log-message
    "UP",
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = [
    "F401",
]
"functorch/notebooks/**" = [
    "F401",
]
"test/typing/reveal/**" = [
    "F821",
]
"test/torch_np/numpy_tests/**" = [
    "F821",
    "NPY201",
]
"test/dynamo/test_bytecode_utils.py" = [
    "F821",
]
"test/dynamo/test_debug_utils.py" = [
    "UP037",
]
"test/jit/**" = [
    "PLR0133", # tests require this for JIT
    "PYI",
    "RUF015",
    "UP", # We don't want to modify the jit test as they test specify syntax
]
"test/test_jit.py" = [
    "PLR0133", # tests require this for JIT
    "PYI",
    "RUF015",
    "UP", # We don't want to modify the jit test as they test specify syntax
]
"test/inductor/test_torchinductor.py" = [
    "UP037",
]
# autogenerated #TODO figure out why file level noqa is ignored
"torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
"torch/_inductor/autoheuristic/artifacts/**" = ["F401", "F501"]
"torchgen/api/types/__init__.py" = [
    "F401",
    "F403",
]
"torchgen/executorch/api/types/__init__.py" = [
    "F401",
    "F403",
]
"torch/utils/collect_env.py" = [
    "UP", # collect_env.py needs to work with older versions of Python
]
"torch/_vendor/**" = [
    "UP", # No need to mess with _vendor
]