diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 12733c3551..7fa89cca38 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -269,6 +269,15 @@ data sparsity based performance optimizations. unset ORTMODULE_CACHE_DIR # Disable ``` +#### ORTMODULE_USE_EFFICIENT_ATTENTION + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch's version is 2.1.1 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_EFFICIENT_ATTENTION=1 + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* @@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results ``` +#### ORTMODULE_USE_FLASH_ATTENTION + +- **Feature Area**: *ORTMODULE/TritonOp* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled, and CUDA device capability is 8.0 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_FLASH_ATTENTION=1 + ``` + #### ORTMODULE_TRITON_DEBUG - **Feature Area**: *ORTMODULE/TritonOp* diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index dc9e0c18ea..3213a8831a 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -5,6 +5,8 @@ import os +import torch + from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401 @@ -17,7 +19,12 @@ _all_kernels = [ "slice_scel_backward", ] -if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: +if ( + "ORTMODULE_USE_FLASH_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1 + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8 +): from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401 _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index d215e12f81..3d3538a62d 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -5,9 +5,16 @@ import os +import torch +from packaging.version import Version + _all_optimizers = [] -if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: +if ( + "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 + and Version(torch.__version__) >= Version("2.1.1") +): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 _all_optimizers.append("optimize_graph_for_aten_efficient_attention") diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index 94bd41293b..b1e8809f03 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -245,31 +245,25 @@ _PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Softmax", False, [(0, 0, 0)]), # 7 - ("Cast", False, [(7, 0, 0)]), # 8 - ("MatMul", False, [(8, 0, 0)]), # 9 - ("Transpose", True, [(9, 0, 1)]), # 10 - ("Transpose", False, [(9, 0, 0)]), # 11 - ("FusedMatMul", False, [(10, 0, 1)]), # 12 - ("Cast", False, [(12, 0, 0)]), # 13 - ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 - ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 - ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 - ("Mul", False, [(15, 0, 0)]), # 17 - ("Mul", False, [(16, 0, 0)]), # 18 - ("Identity", False, [(17, 0, 0)]), # 19 - ("Identity", False, [(18, 0, 0)]), # 20 - ("Cast", False, [(19, 0, 0)]), # 21 - ("Cast", False, [(20, 0, 0)]), # 22 - ("Transpose", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("FusedMatMul", False, [(8, 0, 0)]), # 25 - ("Transpose", True, [(25, 0, 1)]), # 26 - ("Transpose", False, [(25, 0, 0)]), # 27 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Softmax", False, [(0, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 + ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 + ("Mul", False, [(11, 0, 0)]), # 13 + ("Mul", False, [(12, 0, 0)]), # 14 + ("Identity", False, [(13, 0, 0)]), # 15 + ("Identity", False, [(14, 0, 0)]), # 16 + ("Transpose", False, [(15, 0, 0)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(5, 0, 0)]), # 19 + ("Transpose", True, [(19, 0, 1)]), # 20 + ("Transpose", False, [(19, 0, 0)]), # 21 ] @@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[8], "to", 10) - and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[10].input[0], - nodes[11].output[0], - nodes[26].input[0], - nodes[23].output[0], - nodes[24].output[0], - nodes[27].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[20].input[0], + nodes[17].output[0], + nodes[18].output[0], + nodes[21].output[0], "", False, scale_value_1, @@ -315,39 +306,32 @@ _PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Add", False, [(0, 0, 0)]), # 7 - ("Cast", True, [(7, 0, 1)]), # 8 - ("Slice", True, [(8, 0, 0)]), # 9 - ("Slice", True, [(9, 0, 0)]), # 10 - ("Unsqueeze", True, [(9, 0, 2)]), # 11 - ("Gather", True, [(11, 0, 0)]), # 12 - ("Shape", True, [(12, 0, 0)]), # 13 - ("Softmax", False, [(7, 0, 0)]), # 14 - ("Cast", False, [(14, 0, 0)]), # 15 - ("MatMul", False, [(15, 0, 0)]), # 16 - ("Transpose", True, [(16, 0, 1)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(17, 0, 1)]), # 19 - ("Cast", False, [(19, 0, 0)]), # 20 - ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 - ("Identity", False, [(21, 0, 0)]), # 22 - ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23 - ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 - ("Mul", False, [(23, 0, 0)]), # 25 - ("Mul", False, [(24, 0, 0)]), # 26 - ("Identity", False, [(25, 0, 0)]), # 27 - ("Identity", False, [(26, 0, 0)]), # 28 - ("Cast", False, [(27, 0, 0)]), # 29 - ("Cast", False, [(28, 0, 0)]), # 30 - ("Transpose", False, [(29, 0, 0)]), # 31 - ("Transpose", False, [(30, 0, 0)]), # 32 - ("FusedMatMul", False, [(15, 0, 0)]), # 33 - ("Transpose", True, [(33, 0, 1)]), # 34 - ("Transpose", False, [(33, 0, 0)]), # 35 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Add", False, [(0, 0, 0)]), # 5 + ("Slice", True, [(5, 0, 1)]), # 6 + ("Slice", True, [(6, 0, 0)]), # 7 + ("Unsqueeze", True, [(6, 0, 2)]), # 8 + ("Gather", True, [(8, 0, 0)]), # 9 + ("Shape", True, [(9, 0, 0)]), # 10 + ("Softmax", False, [(5, 0, 0)]), # 11 + ("MatMul", False, [(11, 0, 0)]), # 12 + ("Transpose", True, [(12, 0, 1)]), # 13 + ("Transpose", False, [(12, 0, 0)]), # 14 + ("FusedMatMul", False, [(13, 0, 1)]), # 15 + ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 + ("Identity", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 + ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 + ("Mul", False, [(18, 0, 0)]), # 20 + ("Mul", False, [(19, 0, 0)]), # 21 + ("Identity", False, [(20, 0, 0)]), # 22 + ("Identity", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("Transpose", False, [(23, 0, 0)]), # 25 + ("FusedMatMul", False, [(11, 0, 0)]), # 26 + ("Transpose", True, [(26, 0, 1)]), # 27 + ("Transpose", False, [(26, 0, 0)]), # 28 ] @@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[15], "to", 10) - and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[17].input[0], - nodes[18].output[0], - nodes[34].input[0], - nodes[31].output[0], - nodes[32].output[0], - nodes[35].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[13].input[0], + nodes[14].output[0], + nodes[27].input[0], + nodes[24].output[0], + nodes[25].output[0], + nodes[28].output[0], "", False, scale_value_1,