[ORTModule] Adjust Attention Patterns for Efficient Attention ATen Fallback (#18471)

Adjust attention patterns to match latest Whisper+exporter. Also add
some condition check and add docs.
This commit is contained in:
Vincent Wang 2023-11-21 23:24:05 -08:00 committed by GitHub
parent 7c573054b6
commit 3bc9efc7b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 103 additions and 90 deletions

View file

@ -269,6 +269,15 @@ data sparsity based performance optimizations.
unset ORTMODULE_CACHE_DIR # Disable
```
#### ORTMODULE_USE_EFFICIENT_ATTENTION
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch's version is 2.1.1 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually.
```bash
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
```
#### ORTMODULE_USE_FLASH_ATTENTION
- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled, and CUDA device capability is 8.0 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually.
```bash
export ORTMODULE_USE_FLASH_ATTENTION=1
```
#### ORTMODULE_TRITON_DEBUG
- **Feature Area**: *ORTMODULE/TritonOp*

View file

@ -5,6 +5,8 @@
import os
import torch
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401
@ -17,7 +19,12 @@ _all_kernels = [
"slice_scel_backward",
]
if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
if (
"ORTMODULE_USE_FLASH_ATTENTION" in os.environ
and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8
):
from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401
_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])

View file

@ -5,9 +5,16 @@
import os
import torch
from packaging.version import Version
_all_optimizers = []
if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1:
if (
"ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ
and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1
and Version(torch.__version__) >= Version("2.1.1")
):
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401
_all_optimizers.append("optimize_graph_for_aten_efficient_attention")

View file

@ -245,31 +245,25 @@ _PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Softmax", False, [(0, 0, 0)]), # 7
("Cast", False, [(7, 0, 0)]), # 8
("MatMul", False, [(8, 0, 0)]), # 9
("Transpose", True, [(9, 0, 1)]), # 10
("Transpose", False, [(9, 0, 0)]), # 11
("FusedMatMul", False, [(10, 0, 1)]), # 12
("Cast", False, [(12, 0, 0)]), # 13
("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14
("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15
("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16
("Mul", False, [(15, 0, 0)]), # 17
("Mul", False, [(16, 0, 0)]), # 18
("Identity", False, [(17, 0, 0)]), # 19
("Identity", False, [(18, 0, 0)]), # 20
("Cast", False, [(19, 0, 0)]), # 21
("Cast", False, [(20, 0, 0)]), # 22
("Transpose", False, [(21, 0, 0)]), # 23
("Transpose", False, [(22, 0, 0)]), # 24
("FusedMatMul", False, [(8, 0, 0)]), # 25
("Transpose", True, [(25, 0, 1)]), # 26
("Transpose", False, [(25, 0, 0)]), # 27
("Transpose", True, [(1, 0, 0)]), # 3
("Transpose", True, [(2, 0, 0)]), # 4
("Softmax", False, [(0, 0, 0)]), # 5
("MatMul", False, [(5, 0, 0)]), # 6
("Transpose", True, [(6, 0, 1)]), # 7
("Transpose", False, [(6, 0, 0)]), # 8
("FusedMatMul", False, [(7, 0, 1)]), # 9
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11
("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12
("Mul", False, [(11, 0, 0)]), # 13
("Mul", False, [(12, 0, 0)]), # 14
("Identity", False, [(13, 0, 0)]), # 15
("Identity", False, [(14, 0, 0)]), # 16
("Transpose", False, [(15, 0, 0)]), # 17
("Transpose", False, [(16, 0, 0)]), # 18
("FusedMatMul", False, [(5, 0, 0)]), # 19
("Transpose", True, [(19, 0, 1)]), # 20
("Transpose", False, [(19, 0, 0)]), # 21
]
@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[8], "to", 10)
and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[10].input[0],
nodes[11].output[0],
nodes[26].input[0],
nodes[23].output[0],
nodes[24].output[0],
nodes[27].output[0],
nodes[3].input[0],
nodes[4].input[0],
nodes[7].input[0],
nodes[8].output[0],
nodes[20].input[0],
nodes[17].output[0],
nodes[18].output[0],
nodes[21].output[0],
"",
False,
scale_value_1,
@ -315,39 +306,32 @@ _PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Add", False, [(0, 0, 0)]), # 7
("Cast", True, [(7, 0, 1)]), # 8
("Slice", True, [(8, 0, 0)]), # 9
("Slice", True, [(9, 0, 0)]), # 10
("Unsqueeze", True, [(9, 0, 2)]), # 11
("Gather", True, [(11, 0, 0)]), # 12
("Shape", True, [(12, 0, 0)]), # 13
("Softmax", False, [(7, 0, 0)]), # 14
("Cast", False, [(14, 0, 0)]), # 15
("MatMul", False, [(15, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Transpose", False, [(16, 0, 0)]), # 18
("FusedMatMul", False, [(17, 0, 1)]), # 19
("Cast", False, [(19, 0, 0)]), # 20
("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21
("Identity", False, [(21, 0, 0)]), # 22
("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23
("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24
("Mul", False, [(23, 0, 0)]), # 25
("Mul", False, [(24, 0, 0)]), # 26
("Identity", False, [(25, 0, 0)]), # 27
("Identity", False, [(26, 0, 0)]), # 28
("Cast", False, [(27, 0, 0)]), # 29
("Cast", False, [(28, 0, 0)]), # 30
("Transpose", False, [(29, 0, 0)]), # 31
("Transpose", False, [(30, 0, 0)]), # 32
("FusedMatMul", False, [(15, 0, 0)]), # 33
("Transpose", True, [(33, 0, 1)]), # 34
("Transpose", False, [(33, 0, 0)]), # 35
("Transpose", True, [(1, 0, 0)]), # 3
("Transpose", True, [(2, 0, 0)]), # 4
("Add", False, [(0, 0, 0)]), # 5
("Slice", True, [(5, 0, 1)]), # 6
("Slice", True, [(6, 0, 0)]), # 7
("Unsqueeze", True, [(6, 0, 2)]), # 8
("Gather", True, [(8, 0, 0)]), # 9
("Shape", True, [(9, 0, 0)]), # 10
("Softmax", False, [(5, 0, 0)]), # 11
("MatMul", False, [(11, 0, 0)]), # 12
("Transpose", True, [(12, 0, 1)]), # 13
("Transpose", False, [(12, 0, 0)]), # 14
("FusedMatMul", False, [(13, 0, 1)]), # 15
("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16
("Identity", False, [(16, 0, 0)]), # 17
("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18
("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19
("Mul", False, [(18, 0, 0)]), # 20
("Mul", False, [(19, 0, 0)]), # 21
("Identity", False, [(20, 0, 0)]), # 22
("Identity", False, [(21, 0, 0)]), # 23
("Transpose", False, [(22, 0, 0)]), # 24
("Transpose", False, [(23, 0, 0)]), # 25
("FusedMatMul", False, [(11, 0, 0)]), # 26
("Transpose", True, [(26, 0, 1)]), # 27
("Transpose", False, [(26, 0, 0)]), # 28
]
@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[15], "to", 10)
and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[17].input[0],
nodes[18].output[0],
nodes[34].input[0],
nodes[31].output[0],
nodes[32].output[0],
nodes[35].output[0],
nodes[3].input[0],
nodes[4].input[0],
nodes[13].input[0],
nodes[14].output[0],
nodes[27].input[0],
nodes[24].output[0],
nodes[25].output[0],
nodes[28].output[0],
"",
False,
scale_value_1,