mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
[ORTModule] Adjust Attention Patterns for Efficient Attention ATen Fallback (#18471)
Adjust attention patterns to match latest Whisper+exporter. Also add some condition check and add docs.
This commit is contained in:
parent
7c573054b6
commit
3bc9efc7b2
4 changed files with 103 additions and 90 deletions
|
|
@ -269,6 +269,15 @@ data sparsity based performance optimizations.
|
|||
unset ORTMODULE_CACHE_DIR # Disable
|
||||
```
|
||||
|
||||
#### ORTMODULE_USE_EFFICIENT_ATTENTION
|
||||
|
||||
- **Feature Area**: *ORTMODULE/Optimizations*
|
||||
- **Description**: By default, this is disabled. This env var can be used to enable attention fusion and fall back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch version 2.1.1 or above. There are some built-in patterns for attention fusion; if none of the patterns works for your model, you can add a custom one in your user script manually.
|
||||
|
||||
```bash
|
||||
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
|
||||
```
|
||||
|
||||
### 2.2 Memory Optimization
|
||||
|
||||
Q: *Want to run a bigger batch size?*
|
||||
|
|
@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
|
|||
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
|
||||
```
|
||||
|
||||
#### ORTMODULE_USE_FLASH_ATTENTION
|
||||
|
||||
- **Feature Area**: *ORTMODULE/TritonOp*
|
||||
- **Description**: By default, this is disabled. This env var can be used to enable attention fusion and use Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled and a CUDA device with compute capability 8.0 or above. There are some built-in patterns for attention fusion; if none of the patterns works for your model, you can add a custom one in your user script manually.
|
||||
|
||||
```bash
|
||||
export ORTMODULE_USE_FLASH_ATTENTION=1
|
||||
```
|
||||
|
||||
#### ORTMODULE_TRITON_DEBUG
|
||||
|
||||
- **Feature Area**: *ORTMODULE/TritonOp*
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
|
||||
from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401
|
||||
|
||||
|
|
@ -17,7 +19,12 @@ _all_kernels = [
|
|||
"slice_scel_backward",
|
||||
]
|
||||
|
||||
if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
|
||||
if (
|
||||
"ORTMODULE_USE_FLASH_ATTENTION" in os.environ
|
||||
and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 8
|
||||
):
|
||||
from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401
|
||||
|
||||
_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
|
||||
|
|
|
|||
|
|
@ -5,9 +5,16 @@
|
|||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
_all_optimizers = []
|
||||
|
||||
if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1:
|
||||
if (
|
||||
"ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ
|
||||
and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1
|
||||
and Version(torch.__version__) >= Version("2.1.1")
|
||||
):
|
||||
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401
|
||||
|
||||
_all_optimizers.append("optimize_graph_for_aten_efficient_attention")
|
||||
|
|
|
|||
|
|
@ -245,31 +245,25 @@ _PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
|||
("MatMul", False, []), # 0
|
||||
("Mul", True, [(0, 0, 0)]), # 1
|
||||
("Mul", True, [(0, 0, 1)]), # 2
|
||||
("Cast", True, [(1, 0, 0)]), # 3
|
||||
("Cast", True, [(2, 0, 0)]), # 4
|
||||
("Transpose", True, [(3, 0, 0)]), # 5
|
||||
("Transpose", True, [(4, 0, 0)]), # 6
|
||||
("Softmax", False, [(0, 0, 0)]), # 7
|
||||
("Cast", False, [(7, 0, 0)]), # 8
|
||||
("MatMul", False, [(8, 0, 0)]), # 9
|
||||
("Transpose", True, [(9, 0, 1)]), # 10
|
||||
("Transpose", False, [(9, 0, 0)]), # 11
|
||||
("FusedMatMul", False, [(10, 0, 1)]), # 12
|
||||
("Cast", False, [(12, 0, 0)]), # 13
|
||||
("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14
|
||||
("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15
|
||||
("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16
|
||||
("Mul", False, [(15, 0, 0)]), # 17
|
||||
("Mul", False, [(16, 0, 0)]), # 18
|
||||
("Identity", False, [(17, 0, 0)]), # 19
|
||||
("Identity", False, [(18, 0, 0)]), # 20
|
||||
("Cast", False, [(19, 0, 0)]), # 21
|
||||
("Cast", False, [(20, 0, 0)]), # 22
|
||||
("Transpose", False, [(21, 0, 0)]), # 23
|
||||
("Transpose", False, [(22, 0, 0)]), # 24
|
||||
("FusedMatMul", False, [(8, 0, 0)]), # 25
|
||||
("Transpose", True, [(25, 0, 1)]), # 26
|
||||
("Transpose", False, [(25, 0, 0)]), # 27
|
||||
("Transpose", True, [(1, 0, 0)]), # 3
|
||||
("Transpose", True, [(2, 0, 0)]), # 4
|
||||
("Softmax", False, [(0, 0, 0)]), # 5
|
||||
("MatMul", False, [(5, 0, 0)]), # 6
|
||||
("Transpose", True, [(6, 0, 1)]), # 7
|
||||
("Transpose", False, [(6, 0, 0)]), # 8
|
||||
("FusedMatMul", False, [(7, 0, 1)]), # 9
|
||||
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
|
||||
("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11
|
||||
("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12
|
||||
("Mul", False, [(11, 0, 0)]), # 13
|
||||
("Mul", False, [(12, 0, 0)]), # 14
|
||||
("Identity", False, [(13, 0, 0)]), # 15
|
||||
("Identity", False, [(14, 0, 0)]), # 16
|
||||
("Transpose", False, [(15, 0, 0)]), # 17
|
||||
("Transpose", False, [(16, 0, 0)]), # 18
|
||||
("FusedMatMul", False, [(5, 0, 0)]), # 19
|
||||
("Transpose", True, [(19, 0, 1)]), # 20
|
||||
("Transpose", False, [(19, 0, 0)]), # 21
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro
|
|||
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
|
||||
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
|
||||
if not (
|
||||
check_attribute_value(nodes[3], "to", 1)
|
||||
and check_attribute_value(nodes[4], "to", 1)
|
||||
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[8], "to", 10)
|
||||
and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
|
||||
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
|
||||
and scale_value_1 == scale_value_2
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[5].input[0],
|
||||
nodes[6].input[0],
|
||||
nodes[10].input[0],
|
||||
nodes[11].output[0],
|
||||
nodes[26].input[0],
|
||||
nodes[23].output[0],
|
||||
nodes[24].output[0],
|
||||
nodes[27].output[0],
|
||||
nodes[3].input[0],
|
||||
nodes[4].input[0],
|
||||
nodes[7].input[0],
|
||||
nodes[8].output[0],
|
||||
nodes[20].input[0],
|
||||
nodes[17].output[0],
|
||||
nodes[18].output[0],
|
||||
nodes[21].output[0],
|
||||
"",
|
||||
False,
|
||||
scale_value_1,
|
||||
|
|
@ -315,39 +306,32 @@ _PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
|||
("MatMul", False, []), # 0
|
||||
("Mul", True, [(0, 0, 0)]), # 1
|
||||
("Mul", True, [(0, 0, 1)]), # 2
|
||||
("Cast", True, [(1, 0, 0)]), # 3
|
||||
("Cast", True, [(2, 0, 0)]), # 4
|
||||
("Transpose", True, [(3, 0, 0)]), # 5
|
||||
("Transpose", True, [(4, 0, 0)]), # 6
|
||||
("Add", False, [(0, 0, 0)]), # 7
|
||||
("Cast", True, [(7, 0, 1)]), # 8
|
||||
("Slice", True, [(8, 0, 0)]), # 9
|
||||
("Slice", True, [(9, 0, 0)]), # 10
|
||||
("Unsqueeze", True, [(9, 0, 2)]), # 11
|
||||
("Gather", True, [(11, 0, 0)]), # 12
|
||||
("Shape", True, [(12, 0, 0)]), # 13
|
||||
("Softmax", False, [(7, 0, 0)]), # 14
|
||||
("Cast", False, [(14, 0, 0)]), # 15
|
||||
("MatMul", False, [(15, 0, 0)]), # 16
|
||||
("Transpose", True, [(16, 0, 1)]), # 17
|
||||
("Transpose", False, [(16, 0, 0)]), # 18
|
||||
("FusedMatMul", False, [(17, 0, 1)]), # 19
|
||||
("Cast", False, [(19, 0, 0)]), # 20
|
||||
("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21
|
||||
("Identity", False, [(21, 0, 0)]), # 22
|
||||
("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23
|
||||
("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24
|
||||
("Mul", False, [(23, 0, 0)]), # 25
|
||||
("Mul", False, [(24, 0, 0)]), # 26
|
||||
("Identity", False, [(25, 0, 0)]), # 27
|
||||
("Identity", False, [(26, 0, 0)]), # 28
|
||||
("Cast", False, [(27, 0, 0)]), # 29
|
||||
("Cast", False, [(28, 0, 0)]), # 30
|
||||
("Transpose", False, [(29, 0, 0)]), # 31
|
||||
("Transpose", False, [(30, 0, 0)]), # 32
|
||||
("FusedMatMul", False, [(15, 0, 0)]), # 33
|
||||
("Transpose", True, [(33, 0, 1)]), # 34
|
||||
("Transpose", False, [(33, 0, 0)]), # 35
|
||||
("Transpose", True, [(1, 0, 0)]), # 3
|
||||
("Transpose", True, [(2, 0, 0)]), # 4
|
||||
("Add", False, [(0, 0, 0)]), # 5
|
||||
("Slice", True, [(5, 0, 1)]), # 6
|
||||
("Slice", True, [(6, 0, 0)]), # 7
|
||||
("Unsqueeze", True, [(6, 0, 2)]), # 8
|
||||
("Gather", True, [(8, 0, 0)]), # 9
|
||||
("Shape", True, [(9, 0, 0)]), # 10
|
||||
("Softmax", False, [(5, 0, 0)]), # 11
|
||||
("MatMul", False, [(11, 0, 0)]), # 12
|
||||
("Transpose", True, [(12, 0, 1)]), # 13
|
||||
("Transpose", False, [(12, 0, 0)]), # 14
|
||||
("FusedMatMul", False, [(13, 0, 1)]), # 15
|
||||
("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16
|
||||
("Identity", False, [(16, 0, 0)]), # 17
|
||||
("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18
|
||||
("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19
|
||||
("Mul", False, [(18, 0, 0)]), # 20
|
||||
("Mul", False, [(19, 0, 0)]), # 21
|
||||
("Identity", False, [(20, 0, 0)]), # 22
|
||||
("Identity", False, [(21, 0, 0)]), # 23
|
||||
("Transpose", False, [(22, 0, 0)]), # 24
|
||||
("Transpose", False, [(23, 0, 0)]), # 25
|
||||
("FusedMatMul", False, [(11, 0, 0)]), # 26
|
||||
("Transpose", True, [(26, 0, 1)]), # 27
|
||||
("Transpose", False, [(26, 0, 0)]), # 28
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro
|
|||
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
|
||||
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
|
||||
if not (
|
||||
check_attribute_value(nodes[3], "to", 1)
|
||||
and check_attribute_value(nodes[4], "to", 1)
|
||||
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[15], "to", 10)
|
||||
and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
|
||||
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3])
|
||||
and scale_value_1 == scale_value_2
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[5].input[0],
|
||||
nodes[6].input[0],
|
||||
nodes[17].input[0],
|
||||
nodes[18].output[0],
|
||||
nodes[34].input[0],
|
||||
nodes[31].output[0],
|
||||
nodes[32].output[0],
|
||||
nodes[35].output[0],
|
||||
nodes[3].input[0],
|
||||
nodes[4].input[0],
|
||||
nodes[13].input[0],
|
||||
nodes[14].output[0],
|
||||
nodes[27].input[0],
|
||||
nodes[24].output[0],
|
||||
nodes[25].output[0],
|
||||
nodes[28].output[0],
|
||||
"",
|
||||
False,
|
||||
scale_value_1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue