Attention fusion for UNet onnx model export from PyTorch 2.* (#16629)
### Description

Tested with Stable Diffusion UNet models exported by PyTorch nightly. Example to run:

```
cd onnxruntime/python/tools/transformers/
python optimizer.py --input unet.onnx --output unet_fp16.onnx --model_type unet --float16 --opt_level 0
```
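For scripted use, the same optimization is available through the `optimize_model` Python API. A minimal sketch under the same assumptions as the CLI example (file paths are placeholders; `num_heads=0` and `hidden_size=0` request auto-detection, which is what this change implements for PyTorch 2.* exports):

```python
from onnxruntime.transformers.optimizer import optimize_model

# num_heads=0 / hidden_size=0: let the attention fusion detect both values
# from the graph instead of requiring them on the command line.
model = optimize_model("unet.onnx", model_type="unet", num_heads=0, hidden_size=0, opt_level=0)

# Equivalent of --float16: convert the optimized model to fp16, then save.
model.convert_float_to_float16()
model.save_model_to_file("unet_fp16.onnx")
```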
parent b4bf7d5044
commit 2de5807703
1 changed file with 159 additions and 66 deletions
```diff
@@ -39,41 +39,72 @@ class FusionAttentionUnet(Fusion):
         self.num_heads_warning = True
         self.hidden_size_warning = True
 
-    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
-        """Detect num_heads and hidden_size from a reshape node.
+    def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
+        """Detect num_heads from a reshape node.
 
         Args:
             reshape_q (NodeProto): reshape node for Q
-            add_q (NodeProto): add node for Q
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
         Returns:
+            int: num_heads, or 0 if not found
+        """
+        num_heads = 0
+        if is_torch2:
+            # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
+            reshape_parent = self.model.get_parent(reshape_q, 1)
+            if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
+                num_heads = self.model.get_constant_value(reshape_parent.input[2])
+                if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
+                    num_heads = int(num_heads)
+        else:
+            # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
+            q_shape_value = self.model.get_constant_value(reshape_q.input[1])
+            if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
+                num_heads = int(q_shape_value[2])
+
+        if isinstance(num_heads, int) and num_heads > 0:
+            return num_heads
+
+        return 0
+
+    def get_hidden_size(self, layernorm_node):
+        """Detect hidden_size from LayerNormalization node.
+        Args:
+            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+        Returns:
+            int: hidden_size, or 0 if not found
+        """
+        layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
+        if layernorm_bias:
+            return NumpyHelper.to_array(layernorm_bias).shape[0]
+
+        return 0
+
+    def get_num_heads_and_hidden_size(
+        self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
+    ) -> Tuple[int, int]:
+        """Detect num_heads and hidden_size.
+
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
+            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
+        Returns:
             Tuple[int, int]: num_heads and hidden_size
         """
-
-        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
-        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
-        if q_shape_value is None:
-            logger.debug(f"{reshape_q.input[1]} is not constant.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
-            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, -1].")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        num_heads = q_shape_value[2]
-
-        layernorm_bias = self.model.get_initializer(layernorm_node.input[1])
-        if layernorm_bias is None:
-            logger.debug(f"{layernorm_node.input[1]} is not initializer.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        hidden_size = NumpyHelper.to_array(layernorm_bias).shape[0]
+        num_heads = self.get_num_heads(reshape_q, is_torch2)
+        if num_heads <= 0:
+            num_heads = self.num_heads  # Fall back to user specified value
 
         if self.num_heads > 0 and num_heads != self.num_heads:
             if self.num_heads_warning:
                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                 self.num_heads_warning = False  # Do not show the warning more than once
 
+        hidden_size = self.get_hidden_size(layernorm_node)
+        if hidden_size <= 0:
+            hidden_size = self.hidden_size  # Fall back to user specified value
+
         if self.hidden_size > 0 and hidden_size != self.hidden_size:
             if self.hidden_size_warning:
                 logger.warning(
```
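A note on the shape convention this hunk relies on: ONNX `Reshape` treats a `0` in the target shape as "keep this input dimension" and `-1` as "infer it", so a fused head-split reshape carries a constant like `[0, 0, num_heads, head_size]` (or `-1` for head_size), and `get_num_heads` simply reads index 2. A small NumPy analogue with hypothetical sizes:

```python
import numpy as np

# Hypothetical Q activation: batch=2, seq_len=4, hidden_size=8.
q = np.zeros((2, 4, 8), dtype=np.float32)

# ONNX Reshape with shape [0, 0, num_heads, -1] keeps dims 0 and 1 and
# infers head_size = hidden_size // num_heads. NumPy equivalent:
num_heads = 2
q_heads = q.reshape(q.shape[0], q.shape[1], num_heads, -1)
print(q_heads.shape)  # (2, 4, 2, 4) -- index 2 of the target shape is num_heads
```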
```diff
@@ -359,60 +390,21 @@ class FusionAttentionUnet(Fusion):
         children_nodes = input_name_to_nodes[root_input]
         skip_add = None
         for node in children_nodes:
-            if node.op_type == "Add":  # or node.op_type == "SkipLayerNormalization":
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
                 skip_add = node
                 break
 
         if skip_add is None:
             return
 
-        another_input = 1 if skip_add.input[0] == root_input else 0
-        qkv_nodes = self.model.match_parent_path(
-            skip_add,
-            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
-            [another_input, None, None, 0, 0, 0],
-        )
-
-        if qkv_nodes is None:
+        match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
+        if match_qkv is None:
             return
 
-        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
-
-        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
-        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
-        if v_nodes is None:
-            logger.debug("fuse_attention: failed to match v path")
-            return
-        (_, _, _, matmul_v) = v_nodes
-
-        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
-        if qk_nodes is not None:
-            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
-        else:
-            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
-            if qk_nodes is not None:
-                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
-            else:
-                logger.debug("fuse_attention: failed to match qk path")
-                return
-
-        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
-        if q_nodes is None:
-            logger.debug("fuse_attention: failed to match q path")
-            return
-        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
-
-        k_nodes = self.model.match_parent_path(
-            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
-        )
-        if k_nodes is None:
-            logger.debug("fuse_attention: failed to match k path")
-            return
-
-        (_, _, _, _, matmul_k) = k_nodes
+        is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
 
         attention_last_node = reshape_qkv
 
-        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
+        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
         if q_num_heads <= 0:
             logger.debug("fuse_attention: failed to detect num_heads")
             return
```
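A note on the `match_qkv_torch1(...) or match_qkv_torch2(...)` dispatch above: both matchers return either `None` or a tuple whose first element is the `is_torch2` flag. Since a non-empty tuple is truthy even when it starts with `False`, the `or` falls through to the PyTorch 2.* matcher only when the PyTorch 1.* pattern was not found. A toy illustration with hypothetical stand-ins for the two matchers:

```python
def match_torch1():
    return None  # PyTorch 1.* pattern not found in the graph

def match_torch2():
    return (True, "reshape_qkv", "transpose_qkv")  # first element is is_torch2

match = match_torch1() or match_torch2()
print(match)           # (True, 'reshape_qkv', 'transpose_qkv')
print(bool((False,)))  # True: even a tuple starting with False is truthy
```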
```diff
@@ -437,3 +429,104 @@ class FusionAttentionUnet(Fusion):
 
         # Use prune graph to remove nodes since they are shared by all attention nodes.
         self.prune_graph = True
+
+    def match_qkv_torch1(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 1.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [another_input, None, None, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
+
+        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
+        else:
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+            if qk_nodes is not None:
+                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+            else:
+                logger.debug("fuse_attention: failed to match qk path")
+                return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_, _, _, _, matmul_k) = k_nodes
+
+        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+
+    def match_qkv_torch2(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 2.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [another_input, None, None, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_mul_k, _, _, matmul_k) = k_nodes
+
+        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
+        mul_q_nodes = self.model.match_parent_path(
+            mul_q,
+            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+            [None, 0, 1, 0, 0, 0, 0, 0],
+        )
+        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+            logger.debug("fuse_attention: failed to match mul_q path")
+            return None
+
+        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
```
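The `sqrt(1.0/sqrt(head_size))` scalar matched in `match_qkv_torch2` is the usual softmax scale split evenly between Q and K: applied to both operands of the Q×K MatMul, it multiplies out to `1/sqrt(head_size)`. A quick check with a hypothetical head size:

```python
import math

head_size = 64  # hypothetical value for illustration
scale = math.sqrt(1.0 / math.sqrt(head_size))

# Q*scale times K*scale contributes scale**2 == 1/sqrt(head_size),
# the standard scaled-dot-product-attention factor.
assert math.isclose(scale * scale, 1.0 / math.sqrt(head_size))
print(scale * scale)  # 0.125
```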