onnxruntime/orttraining/orttraining/python
Vincent Wang b7408f7389
[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR adds support for efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable.
- Integrate Triton Flash attention, which requires
triton==2.0.0.dev20221202 and an A100 or H100 GPU.
Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
- A Python transformer tool that matches sub-graphs by config, so that
new transformers can be written quickly.

The current transformers support attention masks for both efficient
attention and flash attention, and dropout for efficient attention only.
To support more training scenarios (such as the causal mask in GPT2),
more transformers need to be added.

The feature is guarded by environment variables, so it won't affect
any current behavior unless enabled. Since it requires specific
PyTorch/Triton versions, related tests are not added for now.
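
The opt-in described above could be scripted roughly as follows. This is a minimal sketch, not part of the PR: the two variable names come from the commit message, while the helper `enable_attention_flags` and its parameters are hypothetical. It also assumes the flags are read when the model is wrapped by ORTModule, so they are set beforehand.

```python
import os

# Variable names as given in the commit message.
EFFICIENT_ATTENTION_FLAG = "ORTMODULE_USE_EFFICIENT_ATTENTION"
FLASH_ATTENTION_FLAG = "ORTMODULE_USE_FLASH_ATTENTION"


def enable_attention_flags(efficient=False, flash=False, env=None):
    """Hypothetical helper: set the requested opt-in flags to "1".

    `env` defaults to os.environ; pass a plain dict to try it out
    without touching the real process environment.
    Returns the resulting value of both flags ("0" when unset).
    """
    env = os.environ if env is None else env
    if efficient:
        # Per the commit message, needs PyTorch 2.2.0 dev or newer.
        env[EFFICIENT_ATTENTION_FLAG] = "1"
    if flash:
        # Per the commit message, needs triton==2.0.0.dev20221202
        # and an A100 or H100 GPU.
        env[FLASH_ATTENTION_FLAG] = "1"
    return {
        name: env.get(name, "0")
        for name in (EFFICIENT_ATTENTION_FLAG, FLASH_ATTENTION_FLAG)
    }


# Assumed usage: set the flag first, then wrap the model.
# enable_attention_flags(efficient=True)
# from onnxruntime.training.ortmodule import ORTModule
# model = ORTModule(model)
```

Leaving both flags unset keeps the existing behavior, which matches the guard described in the commit message.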
2023-10-27 10:29:27 +08:00
..
deprecated Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
training [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) 2023-10-27 10:29:27 +08:00
checkpointing_utils.py [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
ort_trainer.py [Linter] Bump ruff and remove pylint (#17797) 2023-10-05 21:07:33 -07:00
orttraining_pybind_common.h Re-work global objects dependancies in pybind layer. (#14941) 2023-03-10 13:55:31 -08:00
orttraining_pybind_state.cc Support inplace update for PythonOp/Grad (#17687) 2023-10-10 21:36:45 -07:00
orttraining_python_module.cc Python API to check whether collective ops are available or not (#17730) 2023-09-29 14:11:05 -07:00
orttraining_python_module_eager.h Run clang-format in CI (#15524) 2023-04-18 09:26:58 -07:00
pt_patch.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00