From 2de5807703cb1553cc299ecd203cd531ab62e77b Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Tue, 11 Jul 2023 14:35:48 -0700
Subject: [PATCH] Attention fusion for UNet onnx model export from PyTorch 2.* (#16629)

### Description
Tested with Stable Diffusion UNet models exported by PyTorch nightly.

Example to run:
```
cd onnxruntime/python/tools/transformers/
python optimizer.py --input unet.onnx --output unet_fp16.onnx --model_type unet --float16 --opt_level 0
```

---
 .../transformers/fusion_attention_unet.py     | 225 +++++++++++++-----
 1 file changed, 159 insertions(+), 66 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index dcbc640923..f286206e5b 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -39,41 +39,72 @@ class FusionAttentionUnet(Fusion):
         self.num_heads_warning = True
         self.hidden_size_warning = True
 
-    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
-        """Detect num_heads and hidden_size from a reshape node.
+    def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
+        """Detect num_heads from a reshape node.
 
         Args:
             reshape_q (NodeProto): reshape node for Q
-            add_q (NodeProto): add node for Q
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
+        Returns:
+            int: num_heads, or 0 if not found
+        """
+        num_heads = 0
+        if is_torch2:
+            # We assume that reshape fusion has been done, so the shape is a Concat of [0, 0, num_heads, head_size].
+            reshape_parent = self.model.get_parent(reshape_q, 1)
+            if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
+                num_heads = self.model.get_constant_value(reshape_parent.input[2])
+                if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
+                    num_heads = int(num_heads)
+        else:
+            # We assume that reshape fusion has been done, so the shape is a constant tensor like [0, 0, num_heads, head_size].
+            q_shape_value = self.model.get_constant_value(reshape_q.input[1])
+            if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
+                num_heads = int(q_shape_value[2])
+        if isinstance(num_heads, int) and num_heads > 0:
+            return num_heads
+
+        return 0
+
+    def get_hidden_size(self, layernorm_node):
+        """Detect hidden_size from a LayerNormalization node.
+        Args:
+            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+        Returns:
+            int: hidden_size, or 0 if not found
+        """
+        layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
+        if layernorm_bias:
+            return NumpyHelper.to_array(layernorm_bias).shape[0]
+
+        return 0
+
+    def get_num_heads_and_hidden_size(
+        self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
+    ) -> Tuple[int, int]:
+        """Detect num_heads and hidden_size.
+
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+            is_torch2 (bool): graph pattern is from PyTorch 2.*
         Returns:
             Tuple[int, int]: num_heads and hidden_size
         """
-
-        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
-        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
-        if q_shape_value is None:
-            logger.debug(f"{reshape_q.input[1]} is not constant.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
-            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, -1].")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        num_heads = q_shape_value[2]
-
-        layernorm_bias = self.model.get_initializer(layernorm_node.input[1])
-        if layernorm_bias is None:
-            logger.debug(f"{layernorm_node.input[1]} is not initializer.")
-            return self.num_heads, self.hidden_size  # Fall back to user specified value
-
-        hidden_size = NumpyHelper.to_array(layernorm_bias).shape[0]
+        num_heads = self.get_num_heads(reshape_q, is_torch2)
+        if num_heads <= 0:
+            num_heads = self.num_heads  # Fall back to user specified value
 
         if self.num_heads > 0 and num_heads != self.num_heads:
             if self.num_heads_warning:
                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                 self.num_heads_warning = False  # Do not show the warning more than once
 
+        hidden_size = self.get_hidden_size(layernorm_node)
+        if hidden_size <= 0:
+            hidden_size = self.hidden_size  # Fall back to user specified value
+
         if self.hidden_size > 0 and hidden_size != self.hidden_size:
             if self.hidden_size_warning:
                 logger.warning(
@@ -359,60 +390,21 @@ class FusionAttentionUnet(Fusion):
         children_nodes = input_name_to_nodes[root_input]
         skip_add = None
         for node in children_nodes:
-            if node.op_type == "Add":  # or node.op_type == "SkipLayerNormalization":
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
                skip_add = node
                break
 
        if skip_add is None:
            return
 
-        another_input = 1 if skip_add.input[0] == root_input else 0
-        qkv_nodes = self.model.match_parent_path(
-            skip_add,
-            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
-            [another_input, None, None, 0, 0, 0],
-        )
-
-        if qkv_nodes is None:
+        match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
+        if match_qkv is None:
            return
 
-        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
-
-        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
-        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
-        if v_nodes is None:
-            logger.debug("fuse_attention: failed to match v path")
-            return
-        (_, _, _, matmul_v) = v_nodes
-
-        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
-        if qk_nodes is not None:
-            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
-        else:
-            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
-            if qk_nodes is not None:
-                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
-            else:
-                logger.debug("fuse_attention: failed to match qk path")
-                return
-
-        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
-        if q_nodes is None:
-            logger.debug("fuse_attention: failed to match q path")
-            return
-        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
-
-        k_nodes = self.model.match_parent_path(
-            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
-        )
-        if k_nodes is None:
-            logger.debug("fuse_attention: failed to match k path")
-            return
-
-        (_, _, _, _, matmul_k) = k_nodes
+        is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
 
         attention_last_node = reshape_qkv
 
-        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
+        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
         if q_num_heads <= 0:
             logger.debug("fuse_attention: failed to detect num_heads")
             return
@@ -437,3 +429,104 @@ class FusionAttentionUnet(Fusion):
 
         # Use prune graph to remove nodes since they are shared by all attention nodes.
         self.prune_graph = True
+
+    def match_qkv_torch1(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 1.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [another_input, None, None, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
+
+        # No bias. For cross-attention, the input of the MatMul is the encoder_hidden_states graph input.
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
+        else:
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+            if qk_nodes is not None:
+                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+            else:
+                logger.debug("fuse_attention: failed to match qk path")
+                return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_, _, _, _, matmul_k) = k_nodes
+
+        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+
+    def match_qkv_torch2(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 2.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [another_input, None, None, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_mul_k, _, _, matmul_k) = k_nodes
+
+        # Each of Q and K is scaled by sqrt(1.0/sqrt(head_size)), so their product carries the usual 1/sqrt(head_size).
+        mul_q_nodes = self.model.match_parent_path(
+            mul_q,
+            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+            [None, 0, 1, 0, 0, 0, 0, 0],
+        )
+        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+            logger.debug("fuse_attention: failed to match mul_q path")
+            return None
+
+        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
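
A note on the scaling matched in `match_qkv_torch2`: the PyTorch 2.* export multiplies Q and K each by sqrt(1.0/sqrt(head_size)) before the QK MatMul, while the PyTorch 1.* export applies a single 1/sqrt(head_size) `Mul` to the product. The two are mathematically equivalent. A minimal NumPy sketch (not part of the patch; shapes are hypothetical) verifying this:

```python
import numpy as np

# Hypothetical shapes: (num_heads, seq_len, head_size).
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64, 40), dtype=np.float32)
k = rng.standard_normal((8, 64, 40), dtype=np.float32)
head_size = q.shape[-1]

# PyTorch 2.* pattern: Q and K are each pre-scaled by sqrt(1/sqrt(head_size))
# (the two Mul nodes matched in match_qkv_torch2).
scale = np.sqrt(1.0 / np.sqrt(head_size))
scores_torch2 = (q * scale) @ np.swapaxes(k * scale, -1, -2)

# PyTorch 1.* pattern: a single Mul scales the QK product by 1/sqrt(head_size).
scores_torch1 = (q @ np.swapaxes(k, -1, -2)) / np.sqrt(head_size)

assert np.allclose(scores_torch2, scores_torch1, rtol=1e-4, atol=1e-5)
```

This is why the fusion can feed either pattern into the same Attention operator: the effective attention scale is 1/sqrt(head_size) in both cases.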
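
The two reshape patterns that `get_num_heads` distinguishes can be sketched with `onnx.helper` (tensor names here are hypothetical; both patterns assume reshape fusion has already run, as the code comments note):

```python
import numpy as np
from onnx import helper, numpy_helper

# PyTorch 1.* pattern: Reshape's shape input is a constant like [0, 0, num_heads, -1];
# get_num_heads reads num_heads from element 2 of that constant.
q_shape = numpy_helper.from_array(np.array([0, 0, 8, -1], dtype=np.int64), name="q_shape")
reshape_torch1 = helper.make_node("Reshape", inputs=["q", "q_shape"], outputs=["q_4d"])

# PyTorch 2.* pattern: the shape is produced at runtime by a 4-input Concat whose
# third input is a one-element constant holding num_heads; get_num_heads walks to
# the Concat parent of the Reshape and reads that constant instead.
num_heads = numpy_helper.from_array(np.array([8], dtype=np.int64), name="num_heads")
concat = helper.make_node(
    "Concat", inputs=["dim0", "dim1", "num_heads", "head_size"], outputs=["q_shape_2"], axis=0
)
reshape_torch2 = helper.make_node("Reshape", inputs=["q", "q_shape_2"], outputs=["q_4d"])
```

`get_hidden_size` is simpler: it reads the length of the LayerNormalization bias initializer (input 2), which equals the model's hidden dimension.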
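
For completeness, the command line in the description can also be driven programmatically. A sketch assuming an onnxruntime build whose `optimizer.optimize_model` accepts the `unet` model type (file names are placeholders):

```python
from onnxruntime.transformers import optimizer

# num_heads=0 and hidden_size=0 let the fusion detect both values from the graph
# via get_num_heads / get_hidden_size, falling back to user-specified values otherwise.
opt_model = optimizer.optimize_model(
    "unet.onnx",
    model_type="unet",
    num_heads=0,
    hidden_size=0,
    opt_level=0,
)
opt_model.convert_float_to_float16()  # equivalent to the --float16 flag
opt_model.save_model_to_file("unet_fp16.onnx")
```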