Attention fusion for UNet ONNX models exported from PyTorch 2.* (#16629)

### Description
Tested with Stable Diffusion UNet models exported by PyTorch nightly.

Example to run:
```
cd onnxruntime/python/tools/transformers/
python optimizer.py --input unet.onnx --output unet_fp16.onnx --model_type unet --float16 --opt_level 0
```
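
For context, the input model can be produced with a plain `torch.onnx.export` of the diffusers UNet. The sketch below is illustrative only; the wrapper, model ID, input shapes, and opset are assumptions, not part of this commit:

```python
# Hypothetical export sketch for producing unet.onnx with PyTorch 2.*.
# Assumes the diffusers package; model ID, shapes, and opset are placeholders.
import torch
from diffusers import UNet2DConditionModel


class UNetWrapper(torch.nn.Module):
    """Unwraps the diffusers output object so the export sees a plain tensor."""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states):
        return self.unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]


unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
model = UNetWrapper(unet.eval())

# Dummy inputs: latent sample, timestep, and CLIP text embeddings (SD 1.5 shapes).
sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor(999, dtype=torch.int64)
encoder_hidden_states = torch.randn(1, 77, 768)

torch.onnx.export(
    model,
    (sample, timestep, encoder_hidden_states),
    "unet.onnx",
    input_names=["sample", "timestep", "encoder_hidden_states"],
    output_names=["out_sample"],
    opset_version=17,
)
```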
Commit 2de5807703 by Tianlei Wu, committed 2023-07-11.
```diff
@@ -39,41 +39,72 @@ class FusionAttentionUnet(Fusion):
         self.num_heads_warning = True
         self.hidden_size_warning = True
 
-    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
-        """Detect num_heads and hidden_size from a reshape node.
+    def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
+        """Detect num_heads from a reshape node.
 
         Args:
             reshape_q (NodeProto): reshape node for Q
-            add_q (NodeProto): add node for Q
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
         Returns:
+            int: num_heads, or 0 if not found
+        """
+        num_heads = 0
+        if is_torch2:
+            # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size].
+            reshape_parent = self.model.get_parent(reshape_q, 1)
+            if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
+                num_heads = self.model.get_constant_value(reshape_parent.input[2])
+                if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
+                    num_heads = int(num_heads)
+        else:
+            # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size].
+            q_shape_value = self.model.get_constant_value(reshape_q.input[1])
+            if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
+                num_heads = int(q_shape_value[2])
+
+        if isinstance(num_heads, int) and num_heads > 0:
+            return num_heads
+
+        return 0
+
+    def get_hidden_size(self, layernorm_node):
+        """Detect hidden_size from a LayerNormalization node.
+
+        Args:
+            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+        Returns:
+            int: hidden_size, or 0 if not found
+        """
+        layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
+        if layernorm_bias:
+            return NumpyHelper.to_array(layernorm_bias).shape[0]
+
+        return 0
+
+    def get_num_heads_and_hidden_size(
+        self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
+    ) -> Tuple[int, int]:
+        """Detect num_heads and hidden_size.
+
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
+            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
+        Returns:
             Tuple[int, int]: num_heads and hidden_size
         """
-        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
-        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
-        if q_shape_value is None:
-            logger.debug(f"{reshape_q.input[1]} is not constant.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
-            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, -1].")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        num_heads = q_shape_value[2]
-
-        layernorm_bias = self.model.get_initializer(layernorm_node.input[1])
-        if layernorm_bias is None:
-            logger.debug(f"{layernorm_node.input[1]} is not initializer.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        hidden_size = NumpyHelper.to_array(layernorm_bias).shape[0]
+        num_heads = self.get_num_heads(reshape_q, is_torch2)
+        if num_heads <= 0:
+            num_heads = self.num_heads  # Fall back to user specified value
 
         if self.num_heads > 0 and num_heads != self.num_heads:
             if self.num_heads_warning:
                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                 self.num_heads_warning = False  # Do not show the warning more than once
 
+        hidden_size = self.get_hidden_size(layernorm_node)
+        if hidden_size <= 0:
+            hidden_size = self.hidden_size  # Fall back to user specified value
+
         if self.hidden_size > 0 and hidden_size != self.hidden_size:
             if self.hidden_size_warning:
                 logger.warning(
```
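
The key difference the new `get_num_heads` handles: with PyTorch 1.* exports, the Reshape for Q consumes a single constant shape tensor `[0, 0, num_heads, head_size]`; with PyTorch 2.* exports, that shape is assembled by a 4-input Concat whose third input is a one-element constant holding num_heads. A minimal sketch of the two encodings, with the ONNX graph machinery elided and illustrative values:

```python
import numpy as np

# PyTorch 1.* pattern: the Reshape for Q has one constant shape input.
q_shape_torch1 = np.array([0, 0, 8, 40], dtype=np.int64)  # [0, 0, num_heads, head_size]
if list(q_shape_torch1.shape) == [4]:
    num_heads = int(q_shape_torch1[2])  # -> 8

# PyTorch 2.* pattern: the shape is built by Concat(dim0, dim1, num_heads, head_size),
# so num_heads is recovered from the Concat's third input, a one-element constant.
concat_input_2 = np.array([8], dtype=np.int64)
if list(concat_input_2.shape) == [1]:
    num_heads = int(concat_input_2)  # -> 8
```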
```diff
@@ -359,60 +390,21 @@ class FusionAttentionUnet(Fusion):
         children_nodes = input_name_to_nodes[root_input]
         skip_add = None
         for node in children_nodes:
-            if node.op_type == "Add":  # or node.op_type == "SkipLayerNormalization":
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
                 skip_add = node
                 break
 
         if skip_add is None:
             return
 
-        another_input = 1 if skip_add.input[0] == root_input else 0
-        qkv_nodes = self.model.match_parent_path(
-            skip_add,
-            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
-            [another_input, None, None, 0, 0, 0],
-        )
-        if qkv_nodes is None:
+        match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
+        if match_qkv is None:
             return
 
-        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
-
-        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
-        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
-        if v_nodes is None:
-            logger.debug("fuse_attention: failed to match v path")
-            return
-        (_, _, _, matmul_v) = v_nodes
-
-        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
-        if qk_nodes is not None:
-            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
-        else:
-            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
-            if qk_nodes is not None:
-                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
-            else:
-                logger.debug("fuse_attention: failed to match qk path")
-                return
-
-        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
-        if q_nodes is None:
-            logger.debug("fuse_attention: failed to match q path")
-            return
-        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
-
-        k_nodes = self.model.match_parent_path(
-            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
-        )
-        if k_nodes is None:
-            logger.debug("fuse_attention: failed to match k path")
-            return
-        (_, _, _, _, matmul_k) = k_nodes
+        is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
 
         attention_last_node = reshape_qkv
 
-        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
+        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
         if q_num_heads <= 0:
             logger.debug("fuse_attention: failed to detect num_heads")
             return
```
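
The dispatch `match_qkv_torch1(...) or match_qkv_torch2(...)` tries the PyTorch 1.* pattern first and falls back to the 2.* pattern. This works because each matcher returns `None` on failure and a non-empty (hence truthy) tuple on success; a toy illustration:

```python
def match_a(x):
    return ("a", x) if x == 1 else None


def match_b(x):
    return ("b", x) if x == 2 else None


assert (match_a(1) or match_b(1)) == ("a", 1)  # first matcher wins
assert (match_a(2) or match_b(2)) == ("b", 2)  # falls back to the second
assert (match_a(3) or match_b(3)) is None      # both fail -> fusion is skipped
```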
```diff
@@ -437,3 +429,104 @@ class FusionAttentionUnet(Fusion):
         # Use prune graph to remove nodes since they are shared by all attention nodes.
         self.prune_graph = True
 
+    def match_qkv_torch1(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 1.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [another_input, None, None, 0, 0, 0],
+        )
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
+
+        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
+        else:
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+            if qk_nodes is not None:
+                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+            else:
+                logger.debug("fuse_attention: failed to match qk path")
+                return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+        (_, _, _, _, matmul_k) = k_nodes
+
+        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+
+    def match_qkv_torch2(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 2.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [another_input, None, None, 0, 0],
+        )
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+        (_mul_k, _, _, matmul_k) = k_nodes
+
+        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
+        mul_q_nodes = self.model.match_parent_path(
+            mul_q,
+            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+            [None, 0, 1, 0, 0, 0, 0, 0],
+        )
+        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+            logger.debug("fuse_attention: failed to match mul_q path")
+            return None
+
+        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
```
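
On the `sqrt(1.0/sqrt(head_size))` scalar matched above: the PyTorch 2.* export multiplies both Q and K by `s = sqrt(1/sqrt(d))` before the MatMul, so the product carries the standard `1/sqrt(d)` softmax scaling, which is why no separate Mul appears in the torch2 qk path. A quick numpy check (shapes illustrative):

```python
import numpy as np

d = 40                        # head_size (illustrative)
q = np.random.rand(2, 5, d)   # (batch * num_heads, seq_len, head_size)
k = np.random.rand(2, 5, d)

s = np.sqrt(1.0 / np.sqrt(d))  # the Mul scalar matched for both Q and K
scaled = (q * s) @ (k * s).transpose(0, 2, 1)
standard = (q @ k.transpose(0, 2, 1)) / np.sqrt(d)
assert np.allclose(scaled, standard)  # s**2 == 1/sqrt(d)
```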