mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
This PR adds support for efficient attention and flash attention in ORTModule, including: - Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable. - Integrate Triton Flash attention, which requires triton==2.0.0.dev20221202 and an A100 or H100 GPU. Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable. - A Python transformer tool to match sub-graphs by config and write transformers quickly. The current transformers support attention mask for both efficient attn and flash attn, and dropout for efficient attn only. To support more training scenarios (such as causal mask in GPT2), more transformers need to be added. The feature is guarded by system environment variables; it won't affect any current behavior if not enabled. Since it requires specific PyTorch/Triton versions, related tests are not added for now.
95 lines
3.1 KiB
TOML
95 lines
3.1 KiB
TOML
# Black formatter settings: 120-column lines, Python 3.7-3.11 targets,
# and two extra excluded paths (cmake, onnxruntime/core/flatbuffers/).
[tool.black]
|
|
line-length = 120
|
|
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
|
|
extend-exclude = "cmake|onnxruntime/core/flatbuffers/"
|
|
target-version = ["py37", "py38", "py39", "py310", "py311"]
|
|
|
|
# isort settings: uses the "black" profile at the same 120-column width as
# [tool.black], and skips cmake/, orttraining/ and the flatbuffers directory.
[tool.isort]
|
|
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
|
|
profile = "black"
|
|
line_length = 120
|
|
extend_skip_glob = [
|
|
"cmake/*",
|
|
"orttraining/*",
|
|
"onnxruntime/core/flatbuffers/*",
|
|
]
|
|
|
|
# pydocstyle settings: docstrings are checked against the Google convention.
[tool.pydocstyle]
|
|
convention = "google"
|
|
|
|
# Pylint naming settings: identifiers listed in good-names are always accepted.
# The list holds every single lowercase letter except "o", plus "ex", "Run" and "_".
[tool.pylint.BASIC]
|
|
good-names = [
|
|
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
|
|
"p", "q", "r", "s", "t", "u", "v", "w", "ex", "Run", "_", "x", "y", "z"
|
|
]
|
|
|
|
# Pylint messages (or message groups) that are turned off for the whole repository.
[tool.pylint.messages_control]
|
|
disable = [
|
|
"format",
|
|
"line-too-long",
|
|
"import-error",
|
|
"no-name-in-module",
|
|
"no-member",
|
|
"too-many-arguments",
|
|
"too-many-locals",
|
|
"too-few-public-methods",
|
|
"missing-docstring",
|
|
"fixme",
|
|
]
|
|
|
|
# pyright settings: the flatbuffers directory is excluded from analysis and
# missing-import diagnostics are switched off.
[tool.pyright]
|
|
exclude = ["onnxruntime/core/flatbuffers/*"]
|
|
reportMissingImports = false
|
|
|
|
# Ruff linter settings: the rule families that are enabled (`select`), the
# individual rules that are suppressed repo-wide (`ignore`), and the rules that
# are listed as `unfixable`.
[tool.ruff]
|
|
# NOTE: Do not create an exclude list. Edit .lintrunner.toml instead
|
|
target-version = "py38"
|
|
select = [
|
|
"B", # flake8-bugbear
|
|
"E", # pycodestyle
|
|
"F", # Pyflakes
|
|
"ISC", # flake8-implicit-str-concat
|
|
"N", # pep8-naming
|
|
"NPY", # numpy
|
|
"PERF", # Perflint
|
|
"PLC", # pylint conventions
|
|
"PLE", # pylint errors
|
|
"PLW", # pylint warnings
|
|
"RUF", # Ruff-specific rules
|
|
"SIM", # flake8-simplify
|
|
"T10", # flake8-debugger
|
|
"UP", # pyupgrade
|
|
"W", # pycodestyle
|
|
"YTT", # flake8-2020
|
|
]
|
|
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
|
|
# Always include a comment to explain why.
|
|
ignore = [
|
|
"B028", # FIXME: Add stacklevel to warnings
|
|
"E501", # Line length controlled by black
|
|
"N803", # Argument casing
|
|
"N812", # Allow import torch.nn.functional as F
|
|
"N999", # Module names
|
|
"NPY002", # np.random.Generator may not always fit our use cases
|
|
"PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives
|
|
"PERF401", # List comprehensions are not always readable
|
|
"SIM102", # We don't prefer always combining if branches
|
|
"SIM108", # We don't encourage ternary operators
|
|
"SIM114", # Don't combine if branches for debuggability
|
|
"SIM116", # Don't use dict lookup to replace if-else
|
|
]
|
|
ignore-init-module-imports = true
|
|
unfixable = [
|
|
"F401", # Unused imports
|
|
"SIM112", # Use upper case for env vars
|
|
]
|
|
|
|
# Ruff rules suppressed only for the listed paths or globs; each key is a
# file pattern and each value is the list of rule codes ignored there.
[tool.ruff.per-file-ignores]
|
|
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
|
|
# Prefer inline ignores with `noqa: xxx`.
|
|
# Eventually this list should become empty.
|
|
"orttraining/orttraining/test/**" = ["N802"] # Function casing
|
|
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
|
|
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
|
|
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
|
|
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.
|