def create_attention_node(
    self,
    mask_index: Optional[str],
    q_matmul: NodeProto,
    k_matmul: NodeProto,
    v_matmul: NodeProto,
    num_heads: int,
    hidden_size: int,
    input: str,
    output: str,
    add_qk_str: Optional[str],
    scale: Optional[float] = None,
) -> Union[NodeProto, None]:
    """Create a fused ``com.microsoft`` Attention node for T5.

    T5-specific replacement for the base-class ``create_attention_node`` so
    that changes to ``fusion_attention.py`` cannot silently alter the graph
    produced for T5 models.

    Args:
        mask_index (str): mask input name, or None when there is no mask.
        q_matmul (NodeProto): MatMul node of the fully connected layer for Q.
        k_matmul (NodeProto): MatMul node of the fully connected layer for K.
        v_matmul (NodeProto): MatMul node of the fully connected layer for V.
        num_heads (int): number of attention heads. If a model is pruned, it
            is the number of heads after pruning.
        hidden_size (int): hidden dimension. If a model is pruned, it is the
            hidden dimension after pruning. Pass 0 to skip the size checks.
        input (str): input name.
        output (str): output name.
        add_qk_str (str): name of the tensor appended as the last Attention
            input, or None to omit it (together with the empty "past" slot).
        scale (Optional[float]): value for the ``scale`` attribute; the
            attribute is not emitted when None.

    Returns:
        Union[NodeProto, None]: the node created, or None if fusion is not
        possible (hidden size not divisible by the number of heads, or one of
        the Q/K/V weights is not a graph initializer).

    Raises:
        AssertionError: if ``num_heads`` is not positive, or the Q/K/V weight
            shapes are inconsistent.
    """
    assert num_heads > 0

    if hidden_size > 0 and (hidden_size % num_heads) != 0:
        logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
        return None

    # Every projection weight must be a constant initializer: the three
    # matrices are packed into a single tensor below. Checking only Q (as
    # before) let a non-constant K or V reach NumpyHelper.to_array as None.
    weights = []
    for matmul in (q_matmul, k_matmul, v_matmul):
        initializer = self.model.get_initializer(matmul.input[1])
        if initializer is None:
            logger.warning(
                f"{matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None
        weights.append(NumpyHelper.to_array(initializer))
    qw, kw, vw = weights

    # assert q and k have same shape as expected
    assert qw.shape == kw.shape

    qw_in_size = qw.shape[0]
    kw_in_size = kw.shape[0]
    vw_in_size = vw.shape[0]

    assert qw_in_size == kw_in_size == vw_in_size

    if hidden_size > 0 and hidden_size != qw_in_size:
        logger.warning(
            f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
            "Please provide a correct input hidden size or pass in 0"
        )

    # Plain int so the tensor dims below are not NumPy scalars.
    qw_out_size = int(np.prod(qw.shape[1:]))
    # Stacking on axis 1 gives (in, 3, out); flattened row-major this is the
    # [in, 3 * out] layout declared in dims, i.e. Q | K | V side by side.
    qkv_weight = np.stack((qw, kw, vw), axis=1)
    qkv_weight_dim = 3 * qw_out_size

    attention_node_name = self.model.create_node_name("Attention")
    weight_name = attention_node_name + "_qkv_weight"

    weight = helper.make_tensor(
        name=weight_name,
        data_type=TensorProto.FLOAT,
        dims=[qw_in_size, qkv_weight_dim],
        vals=qkv_weight.flatten().tolist(),
    )
    self.model.add_initializer(weight, self.this_graph_name)

    # Inputs are positional; an empty string marks an optional input that is
    # not provided. The third slot (bias in the Attention contrib-op schema)
    # is always left empty here.
    attention_inputs = [
        input,
        weight_name,
        "",
        mask_index if mask_index is not None else "",
    ]
    if add_qk_str is not None:
        attention_inputs.append("")  # no past
        attention_inputs.append(add_qk_str)

    attention_node = helper.make_node(
        "Attention",
        inputs=attention_inputs,
        outputs=[output],
        name=attention_node_name,
    )
    attention_node.domain = "com.microsoft"
    attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

    if scale is not None:
        attention_node.attribute.extend([helper.make_attribute("scale", scale)])

    if self.mask_filter_value is not None:
        attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

    return attention_node