From 29a818caa09ad377d6cc019e11f90630e1eaaf66 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 8 Sep 2023 14:17:14 -0700 Subject: [PATCH] Attention fusion for stable diffusion clip model (#17445) Add attention fusion for stable diffusion clip model to improve performance of SD or SDXL --- .../transformers/compare_bert_results.py | 17 +- .../tools/transformers/fusion_attention.py | 12 +- .../transformers/fusion_attention_clip.py | 218 ++++++++++++++++++ .../tools/transformers/fusion_options.py | 3 + .../models/stable_diffusion/README.md | 3 +- .../tools/transformers/onnx_model_clip.py | 9 +- .../test_optimizer_stable_diffusion.py | 156 +++++++++++++ 7 files changed, 403 insertions(+), 15 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_attention_clip.py create mode 100644 onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 4cb9585962..61e4c97c75 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -33,18 +33,20 @@ def run_model(model_path, all_inputs, use_gpu, disable_optimization): return results, latency_list, output_names -def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): +def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): # Validate the output of baseline and treatment, to make sure the results are similar. diff_count = 0 - max_rel_diff = 0 max_abs_diff = 0 for test_case_id, results in enumerate(baseline_results): case_passed = True for i in range(len(results)): treatment_output = treatment_results[test_case_id][i] - rel_diff = np.amax(np.abs((treatment_output - results[i]) / results[i])) abs_diff = np.amax(np.abs(treatment_output - results[i])) - max_rel_diff = max(max_rel_diff, rel_diff) + if verbose and abs_diff > atol: + print("abs_diff", abs_diff) + print("treatment", treatment_output) + print("baseline", results[i]) + max_abs_diff = max(max_abs_diff, abs_diff) if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol): if case_passed: @@ -54,7 +56,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): if verbose: print(f"case {test_case_id} output {i}") print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}") - print(f"rel_diff={rel_diff} abs_diff={abs_diff}") + print(f"abs_diff={abs_diff}") if diff_count == 0: print( @@ -70,8 +72,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): ) print(f"maximum absolute difference={max_abs_diff}") - - print(f"maximum relative difference={max_rel_diff}") + return max_abs_diff, case_passed def run_test( @@ -133,7 +134,7 @@ def run_test( print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms") # Validate the output of baseline and treatment, to make sure the results are similar. - compare(baseline_results, treatment_results, verbose, rtol, atol) + return compare(baseline_results, treatment_results, verbose, rtol, atol) def parse_arguments(): diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 9628e2a741..40f2aee875 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -196,7 +196,7 @@ class FusionAttention(Fusion): def get_add_qk_str(self, add_qk: NodeProto): shape_infer = self.model.infer_runtime_shape(update=True) if shape_infer is None: - return + return None input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) @@ -697,6 +697,7 @@ class FusionAttention(Fusion): present_k: str = "", present_v: str = "", scale: Optional[float] = None, + causal: bool = False, ) -> Union[NodeProto, None]: """Create an Attention node. @@ -717,6 +718,8 @@ class FusionAttention(Fusion): past_v (str): name of input for past V value present_k (str): name of output to store present K value present_v (str): name of output to store present V value + scale: scale before softmax + causal: whether it is uni-directional mask. Returns: Union[NodeProto, None]: the node created or None if failed. @@ -828,7 +831,7 @@ class FusionAttention(Fusion): # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. if self.use_multi_head_attention: - if add_qk_str is not None: + if add_qk_str: logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.") return None @@ -864,7 +867,7 @@ class FusionAttention(Fusion): past_kv = self.concat_kv(past_k, past_v) attention_inputs.append(past_kv) - if add_qk_str is not None: + if add_qk_str: # Convert 4d mask from (B,1,M,M) to (B,N,M,M) # B = batch size, M = max sequence length, N = num heads concat_node_name = self.model.create_node_name("Concat") @@ -901,6 +904,9 @@ class FusionAttention(Fusion): attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + if causal: + attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)]) + if scale is not None: attention_node.attribute.extend([helper.make_attribute("scale", scale)]) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py new file mode 100644 index 0000000000..d400e248d6 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -0,0 +1,218 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Tuple + +from fusion_attention import AttentionMask, FusionAttention +from fusion_options import AttentionMaskFormat +from onnx import NodeProto +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionAttentionClip(FusionAttention): + """ + Fuse Attention subgraph of Clip into one Attention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + attention_mask = AttentionMask(model) + attention_mask.mask_format = AttentionMaskFormat.NoMask + + super().__init__( + model, + hidden_size, + num_heads, + attention_mask, + use_multi_head_attention=False, + search_op_types=["SkipLayerNormalization"], + ) + + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + """Detect num_heads and hidden_size for ONNX model from MiDaS + Args: + reshape_q (NodeProto): reshape node for q + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + concat = self.model.match_parent(reshape_q, "Concat", 1) + if concat is None or len(concat.input) != 4: + return self.num_heads, self.hidden_size + + # The shape is a tensor like [?, ?, num_heads, head_size] + num_head_value = self.model.get_constant_value(concat.input[2]) + if num_head_value is None: + return self.num_heads, self.hidden_size # Fall back to user specified value + + if len(num_head_value) != 1 or num_head_value[0] <= 0: + return self.num_heads, self.hidden_size # Fall back to user specified value + + num_heads = num_head_value[0] + + head_size_value = self.model.get_constant_value(concat.input[3]) + if head_size_value is None: + return self.num_heads, self.hidden_size # Fall back to user specified value + + if len(head_size_value) != 1 or head_size_value[0] <= 0: + return self.num_heads, self.hidden_size # Fall back to user specified value + + head_size = head_size_value[0] + + hidden_size = num_heads * head_size + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + self.num_heads_warning = False # Do not show the warning more than once + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning( + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + skip_input_index = None + node_before_layer_norm = None + for i in [1, 0]: + parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i) + if parent is not None: + skip_input_index = i + node_before_layer_norm = parent + + root_input = None + if node_before_layer_norm is not None: + root_input = node_before_layer_norm.output[0] + else: + # Deal with the first attention after the embedding layer. + for i in [0, 1]: + node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i) + if node_before_layer_norm is None: + continue + child = self.model.find_first_child_by_type( + node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False + ) + if child is None: + continue + root_input = child.output[0] + skip_input_index = i + break + + if skip_input_index is None: + return + + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1 - skip_input_index, None, None, 0, 0, 0], + ) + if qkv_nodes is None: + return + + (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path( + matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] + ) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return + (_, _, reshape_v, add_v, matmul_v) = v_nodes + + add_mask_indices = [] + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], + [0, 0, 0, None, 0], + return_indice=add_mask_indices, + ) + if qk_nodes is None: + logger.debug("fuse_attention: failed to match qk path") + return + assert len(add_mask_indices) == 1 + causal_mask_input_index = 1 - add_mask_indices[0] + + (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes + + q_nodes = self.model.match_parent_path( + matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None] + ) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return + (_, _transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return + + (_transpose_k, _reshape_k, _, _, add_k, matmul_k) = k_nodes + if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input: + logger.debug("fuse_attention: expect to have same input to q, k and v matmul") + return + + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + if num_heads <= 0 or hidden_size <= 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + attention_last_node = reshape_qkv + + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. + causal_mask_nodes = self.model.match_parent_path( + add_mask, + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], + ) + if causal_mask_nodes is None: + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + if causal_mask_nodes is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return + + new_node = self.create_attention_node( + mask_index=None, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + input=root_input, + output=attention_last_node.output[0], + add_qk_str=None, + scale=None, + causal=True, + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 57f0fea99d..69b5cd26f4 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -45,6 +45,9 @@ class FusionOptions: self.enable_gemm_fast_gelu = False self.group_norm_channels_last = True + if model_type == "clip": + self.enable_embed_layer_norm = False + # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. self.attention_mask_format = AttentionMaskFormat.AttentionMask diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index d184224317..facbd3bf69 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -384,8 +384,7 @@ Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' coll There are other optimizations might improve the performance or reduce memory footprint: * Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint. -* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention using Triton compiler to improve performance. +* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance. * Reduce GPU memory footprint by actively deleting buffers for intermediate results. -* Attention fusion in CLIP * Safety Checker Optimization * Leverage FP8 in latest GPU diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 93e8623768..9b4ca03a47 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -5,15 +5,17 @@ from logging import getLogger +from fusion_attention_clip import FusionAttentionClip from onnx import ModelProto -from onnx_model_unet import UnetOnnxModel +from onnx_model_bert import BertOnnxModel logger = getLogger(__name__) -class ClipOnnxModel(UnetOnnxModel): +class ClipOnnxModel(BertOnnxModel): def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads) def get_fused_operator_statistics(self): """ @@ -31,3 +33,6 @@ class ClipOnnxModel(UnetOnnxModel): logger.info(f"Optimized operators:{op_count}") return op_count + + def fuse_attention(self): + self.clip_attention_fusion.apply() diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py new file mode 100644 index 0000000000..cde6b56a66 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py @@ -0,0 +1,156 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import shutil +import unittest + +import pytest +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from compare_bert_results import run_test + from fusion_options import FusionOptions + from optimizer import optimize_model +else: + from onnxruntime.transformers.compare_bert_results import run_test + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.optimizer import optimize_model + +TINY_MODELS = { + "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", +} + + +class TestStableDiffusionOptimization(unittest.TestCase): + def verify_node_count(self, onnx_model, expected_node_count, test_name): + for op_type, count in expected_node_count.items(): + if len(onnx_model.get_nodes_by_op_type(op_type)) != count: + print(f"Counters is not expected in test: {test_name}") + for op, counter in expected_node_count.items(): + print(f"{op}: {len(onnx_model.get_nodes_by_op_type(op))} expected={counter}") + + self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count) + + def verify_clip_optimizer(self, clip_onnx_path, optimized_clip_onnx_path, expected_counters, float16=False): + fusion_options = FusionOptions("clip") + m = optimize_model( + clip_onnx_path, + model_type="clip", + num_heads=0, + hidden_size=0, + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + self.verify_node_count(m, expected_counters, "test_clip") + + if float16: + m.convert_float_to_float16( + keep_io_types=True, + ) + print(m.get_operator_statistics()) + m.save_model_to_file(optimized_clip_onnx_path) + + threshold = 1e-2 if float16 else 3e-3 + max_abs_diff, passed = run_test( + clip_onnx_path, + optimized_clip_onnx_path, + output_dir=None, + batch_size=1, + sequence_length=77, + use_gpu=True, + test_cases=10, + seed=1, + verbose=False, + rtol=1e-1, + atol=threshold, + input_ids_name="input_ids", + segment_ids_name=None, + input_mask_name=None, + mask_type=0, + ) + + self.assertLess(max_abs_diff, threshold) + self.assertTrue(passed) + + @pytest.mark.slow + def test_clip_sd(self): + save_directory = "tiny-random-stable-diffusion" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionPipeline + + base = ORTStableDiffusionPipeline.from_pretrained(model_name, export=True) + base.save_pretrained(save_directory) + + clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 0, + }, + float16=True, + ) + + @pytest.mark.slow + def test_clip_sdxl(self): + save_directory = "tiny-random-stable-diffusion-xl" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion-xl" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionXLPipeline + + base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) + base.save_pretrained(save_directory) + + clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 5, + }, + ) + + clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 5, + }, + ) + + +if __name__ == "__main__": + unittest.main()