onnxruntime/pyproject.toml
Vincent Wang b7408f7389
[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR is to support efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable it.
- Integrate Triton flash attention, which requires
triton==2.0.0.dev20221202 and an A100 or H100 GPU.
Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable it.
- A Python transformer tool that matches sub-graphs by config, so that new
graph transformers can be written quickly.

The current transformers support attention masks for both efficient attention
and flash attention, and dropout for efficient attention only. To support more
training scenarios (such as the causal mask in GPT-2), more transformers need
to be added.

The feature is guarded by environment variables, so it won't affect
any current behavior unless enabled. Since it requires specific
PyTorch/Triton versions, related tests are not added for now.
2023-10-27 10:29:27 +08:00

95 lines
3.1 KiB
TOML

[tool.black]
# Python versions Black targets when formatting.
target-version = [
    "py37",
    "py38",
    "py39",
    "py310",
    "py311",
]
line-length = 120
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
# Black treats this value as a regular expression (hence the literal string).
extend-exclude = 'cmake|onnxruntime/core/flatbuffers/'
[tool.isort]
# Same 120-column limit as Black, using isort's Black-compatible profile.
line_length = 120
profile = "black"
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
extend_skip_glob = ['cmake/*', 'orttraining/*', 'onnxruntime/core/flatbuffers/*']
[tool.pydocstyle]
# Check docstrings against the Google docstring convention.
convention = 'google'
[tool.pylint.BASIC]
# Short identifiers pylint's naming checks should accept: most single letters
# (NOTE(review): "o" is not listed, presumably on purpose), plus "ex", "Run", "_".
good-names = [
    "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
    "p", "q", "r", "s", "t", "u", "v", "w",
    "ex", "Run", "_",
    "x", "y", "z",
]
[tool.pylint.messages_control]
# Pylint messages switched off for the whole repository.
disable = [
    "format", "line-too-long",  # layout and line length are left to Black
    "import-error", "no-name-in-module", "no-member",  # name-resolution checks
    "too-many-arguments", "too-many-locals", "too-few-public-methods",  # size/design checks
    "missing-docstring", "fixme",  # docstring presence and TODO/FIXME markers
]
[tool.pyright]
# Do not report unresolved imports; skip the flatbuffers sources entirely.
reportMissingImports = false
exclude = ['onnxruntime/core/flatbuffers/*']
[tool.ruff]
# NOTE: Do not create an exclude list. Edit .lintrunner.toml instead
target-version = "py38"
# Rule families (selected by code prefix) that Ruff checks.
select = [
    "B", # flake8-bugbear
    "E", # pycodestyle errors
    "F", # Pyflakes
    "ISC", # flake8-implicit-str-concat
    "N", # pep8-naming
    "NPY", # numpy
    "PERF", # Perflint
    "PLC", # pylint conventions
    "PLE", # pylint errors
    "PLW", # pylint warnings
    "RUF", # Ruff-specific rules
    "SIM", # flake8-simplify
    "T10", # flake8-debugger
    "UP", # pyupgrade
    "W", # pycodestyle warnings
    "YTT", # flake8-2020
]
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
# Always include a comment to explain why.
ignore = [
    "B028", # FIXME: Add stacklevel to warnings
    "E501", # Line length controlled by black
    "N803", # Argument casing
    "N812", # Allow import torch.nn.functional as F
    "N999", # Module names
    "NPY002", # np.random.Generator may not always fit our use cases
    "PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives
    "PERF401", # List comprehensions are not always readable
    "SIM102", # We don't always prefer combining nested if statements
    "SIM108", # We don't encourage ternary operators
    "SIM114", # Don't combine if branches, for debuggability
    "SIM116", # Don't use dict lookup to replace if-else
]
# Avoid automatically removing unused imports in __init__.py files.
ignore-init-module-imports = true
# Rules that are still reported but must never be auto-fixed by `ruff --fix`.
unfixable = [
    "F401", # Unused imports
    "SIM112", # Use upper case for env vars
]
[tool.ruff.per-file-ignores]
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
# Prefer inline ignores with `noqa: xxx`.
# Eventually this list should become empty.
# Rules already disabled globally via `ignore` in [tool.ruff] (e.g. PERF203)
# need no per-file entry here.
"orttraining/orttraining/test/**" = ["N802"] # Function casing
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "RUF012"] # N806: use of A for a matrix; RUF012: mutable class attributes
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.