Attention fusion for stable diffusion clip model (#17445)

Add attention fusion for stable diffusion clip model to improve performance of SD or SDXL
This commit is contained in:
Tianlei Wu 2023-09-08 14:17:14 -07:00 committed by GitHub
parent 4d753b74a5
commit 29a818caa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 403 additions and 15 deletions

View file

@ -33,18 +33,20 @@ def run_model(model_path, all_inputs, use_gpu, disable_optimization):
return results, latency_list, output_names
def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
# Validate the output of baseline and treatment, to make sure the results are similar.
diff_count = 0
max_rel_diff = 0
max_abs_diff = 0
for test_case_id, results in enumerate(baseline_results):
case_passed = True
for i in range(len(results)):
treatment_output = treatment_results[test_case_id][i]
rel_diff = np.amax(np.abs((treatment_output - results[i]) / results[i]))
abs_diff = np.amax(np.abs(treatment_output - results[i]))
max_rel_diff = max(max_rel_diff, rel_diff)
if verbose and abs_diff > atol:
print("abs_diff", abs_diff)
print("treatment", treatment_output)
print("baseline", results[i])
max_abs_diff = max(max_abs_diff, abs_diff)
if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
if case_passed:
@ -54,7 +56,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
if verbose:
print(f"case {test_case_id} output {i}")
print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}")
print(f"rel_diff={rel_diff} abs_diff={abs_diff}")
print(f"abs_diff={abs_diff}")
if diff_count == 0:
print(
@ -70,8 +72,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
)
print(f"maximum absolute difference={max_abs_diff}")
print(f"maximum relative difference={max_rel_diff}")
return max_abs_diff, case_passed
def run_test(
@ -133,7 +134,7 @@ def run_test(
print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms")
# Validate the output of baseline and treatment, to make sure the results are similar.
compare(baseline_results, treatment_results, verbose, rtol, atol)
return compare(baseline_results, treatment_results, verbose, rtol, atol)
def parse_arguments():

View file

@ -196,7 +196,7 @@ class FusionAttention(Fusion):
def get_add_qk_str(self, add_qk: NodeProto):
shape_infer = self.model.infer_runtime_shape(update=True)
if shape_infer is None:
return
return None
input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
@ -697,6 +697,7 @@ class FusionAttention(Fusion):
present_k: str = "",
present_v: str = "",
scale: Optional[float] = None,
causal: bool = False,
) -> Union[NodeProto, None]:
"""Create an Attention node.
@ -717,6 +718,8 @@ class FusionAttention(Fusion):
past_v (str): name of input for past V value
present_k (str): name of output to store present K value
present_v (str): name of output to store present V value
scale: scale before softmax
causal: whether it is uni-directional mask.
Returns:
Union[NodeProto, None]: the node created or None if failed.
@ -828,7 +831,7 @@ class FusionAttention(Fusion):
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
if add_qk_str is not None:
if add_qk_str:
logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
return None
@ -864,7 +867,7 @@ class FusionAttention(Fusion):
past_kv = self.concat_kv(past_k, past_v)
attention_inputs.append(past_kv)
if add_qk_str is not None:
if add_qk_str:
# Convert 4d mask from (B,1,M,M) to (B,N,M,M)
# B = batch size, M = max sequence length, N = num heads
concat_node_name = self.model.create_node_name("Concat")
@ -901,6 +904,9 @@ class FusionAttention(Fusion):
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
if causal:
attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
if scale is not None:
attention_node.attribute.extend([helper.make_attribute("scale", scale)])

View file

@ -0,0 +1,218 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple
from fusion_attention import AttentionMask, FusionAttention
from fusion_options import AttentionMaskFormat
from onnx import NodeProto
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionAttentionClip(FusionAttention):
"""
Fuse Attention subgraph of Clip into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
):
attention_mask = AttentionMask(model)
attention_mask.mask_format = AttentionMaskFormat.NoMask
super().__init__(
model,
hidden_size,
num_heads,
attention_mask,
use_multi_head_attention=False,
search_op_types=["SkipLayerNormalization"],
)
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
"""Detect num_heads and hidden_size for ONNX model from MiDaS
Args:
reshape_q (NodeProto): reshape node for q
Returns:
Tuple[int, int]: num_heads and hidden_size
"""
concat = self.model.match_parent(reshape_q, "Concat", 1)
if concat is None or len(concat.input) != 4:
return self.num_heads, self.hidden_size
# The shape is a tensor like [?, ?, num_heads, head_size]
num_head_value = self.model.get_constant_value(concat.input[2])
if num_head_value is None:
return self.num_heads, self.hidden_size # Fall back to user specified value
if len(num_head_value) != 1 or num_head_value[0] <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
num_heads = num_head_value[0]
head_size_value = self.model.get_constant_value(concat.input[3])
if head_size_value is None:
return self.num_heads, self.hidden_size # Fall back to user specified value
if len(head_size_value) != 1 or head_size_value[0] <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
head_size = head_size_value[0]
hidden_size = num_heads * head_size
if self.num_heads > 0 and num_heads != self.num_heads:
if self.num_heads_warning:
logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
self.num_heads_warning = False # Do not show the warning more than once
if self.hidden_size > 0 and hidden_size != self.hidden_size:
if self.hidden_size_warning:
logger.warning(
f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
)
self.hidden_size_warning = False # Do not show the warning more than once
return num_heads, hidden_size
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
skip_input_index = None
node_before_layer_norm = None
for i in [1, 0]:
parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i)
if parent is not None:
skip_input_index = i
node_before_layer_norm = parent
root_input = None
if node_before_layer_norm is not None:
root_input = node_before_layer_norm.output[0]
else:
# Deal with the first attention after the embedding layer.
for i in [0, 1]:
node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
if node_before_layer_norm is None:
continue
child = self.model.find_first_child_by_type(
node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False
)
if child is None:
continue
root_input = child.output[0]
skip_input_index = i
break
if skip_input_index is None:
return
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
[1 - skip_input_index, None, None, 0, 0, 0],
)
if qkv_nodes is None:
return
(_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
v_nodes = self.model.match_parent_path(
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(_, _, reshape_v, add_v, matmul_v) = v_nodes
add_mask_indices = []
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
[0, 0, 0, None, 0],
return_indice=add_mask_indices,
)
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
return
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
(_, _transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
(_transpose_k, _reshape_k, _, _, add_k, matmul_k) = k_nodes
if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input:
logger.debug("fuse_attention: expect to have same input to q, k and v matmul")
return
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if num_heads <= 0 or hidden_size <= 0:
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
return
attention_last_node = reshape_qkv
# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
# of computing causal mask.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
# If the model is exported with batch_size == 1, there is no Concat node
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
new_node = self.create_attention_node(
mask_index=None,
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
input=root_input,
output=attention_last_node.output[0],
add_qk_str=None,
scale=None,
causal=True,
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
# Use prune graph to remove nodes since they are shared by all attention nodes.
self.prune_graph = True

View file

@ -45,6 +45,9 @@ class FusionOptions:
self.enable_gemm_fast_gelu = False
self.group_norm_channels_last = True
if model_type == "clip":
self.enable_embed_layer_norm = False
# Set default to sequence length for BERT model to use fused attention to speed up.
# Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
self.attention_mask_format = AttentionMaskFormat.AttentionMask

View file

@ -384,8 +384,7 @@ Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' coll
There are other optimizations might improve the performance or reduce memory footprint:
* Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint.
* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention using Triton compiler to improve performance.
* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance.
* Reduce GPU memory footprint by actively deleting buffers for intermediate results.
* Attention fusion in CLIP
* Safety Checker Optimization
* Leverage FP8 in latest GPU

View file

@ -5,15 +5,17 @@
from logging import getLogger
from fusion_attention_clip import FusionAttentionClip
from onnx import ModelProto
from onnx_model_unet import UnetOnnxModel
from onnx_model_bert import BertOnnxModel
logger = getLogger(__name__)
class ClipOnnxModel(UnetOnnxModel):
class ClipOnnxModel(BertOnnxModel):
def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)
def get_fused_operator_statistics(self):
"""
@ -31,3 +33,6 @@ class ClipOnnxModel(UnetOnnxModel):
logger.info(f"Optimized operators:{op_count}")
return op_count
def fuse_attention(self):
self.clip_attention_fusion.apply()

View file

@ -0,0 +1,156 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import shutil
import unittest
import pytest
from parity_utilities import find_transformers_source
if find_transformers_source():
from compare_bert_results import run_test
from fusion_options import FusionOptions
from optimizer import optimize_model
else:
from onnxruntime.transformers.compare_bert_results import run_test
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model
TINY_MODELS = {
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
}
class TestStableDiffusionOptimization(unittest.TestCase):
def verify_node_count(self, onnx_model, expected_node_count, test_name):
for op_type, count in expected_node_count.items():
if len(onnx_model.get_nodes_by_op_type(op_type)) != count:
print(f"Counters is not expected in test: {test_name}")
for op, counter in expected_node_count.items():
print(f"{op}: {len(onnx_model.get_nodes_by_op_type(op))} expected={counter}")
self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count)
def verify_clip_optimizer(self, clip_onnx_path, optimized_clip_onnx_path, expected_counters, float16=False):
fusion_options = FusionOptions("clip")
m = optimize_model(
clip_onnx_path,
model_type="clip",
num_heads=0,
hidden_size=0,
opt_level=0,
optimization_options=fusion_options,
use_gpu=True,
)
self.verify_node_count(m, expected_counters, "test_clip")
if float16:
m.convert_float_to_float16(
keep_io_types=True,
)
print(m.get_operator_statistics())
m.save_model_to_file(optimized_clip_onnx_path)
threshold = 1e-2 if float16 else 3e-3
max_abs_diff, passed = run_test(
clip_onnx_path,
optimized_clip_onnx_path,
output_dir=None,
batch_size=1,
sequence_length=77,
use_gpu=True,
test_cases=10,
seed=1,
verbose=False,
rtol=1e-1,
atol=threshold,
input_ids_name="input_ids",
segment_ids_name=None,
input_mask_name=None,
mask_type=0,
)
self.assertLess(max_abs_diff, threshold)
self.assertTrue(passed)
@pytest.mark.slow
def test_clip_sd(self):
save_directory = "tiny-random-stable-diffusion"
if os.path.exists(save_directory):
shutil.rmtree(save_directory, ignore_errors=True)
model_type = "stable-diffusion"
model_name = TINY_MODELS[model_type]
from optimum.onnxruntime import ORTStableDiffusionPipeline
base = ORTStableDiffusionPipeline.from_pretrained(model_name, export=True)
base.save_pretrained(save_directory)
clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx")
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx")
self.verify_clip_optimizer(
clip_onnx_path,
optimized_clip_onnx_path,
expected_counters={
"EmbedLayerNormalization": 0,
"Attention": 5,
"SkipLayerNormalization": 10,
"LayerNormalization": 1,
"Gelu": 0,
"BiasGelu": 0,
},
float16=True,
)
@pytest.mark.slow
def test_clip_sdxl(self):
save_directory = "tiny-random-stable-diffusion-xl"
if os.path.exists(save_directory):
shutil.rmtree(save_directory, ignore_errors=True)
model_type = "stable-diffusion-xl"
model_name = TINY_MODELS[model_type]
from optimum.onnxruntime import ORTStableDiffusionXLPipeline
base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True)
base.save_pretrained(save_directory)
clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx")
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx")
self.verify_clip_optimizer(
clip_onnx_path,
optimized_clip_onnx_path,
expected_counters={
"EmbedLayerNormalization": 0,
"Attention": 5,
"SkipLayerNormalization": 10,
"LayerNormalization": 1,
"Gelu": 0,
"BiasGelu": 5,
},
)
clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx")
optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx")
self.verify_clip_optimizer(
clip_onnx_path,
optimized_clip_onnx_path,
expected_counters={
"EmbedLayerNormalization": 0,
"Attention": 5,
"SkipLayerNormalization": 10,
"LayerNormalization": 1,
"Gelu": 0,
"BiasGelu": 5,
},
)
if __name__ == "__main__":
unittest.main()