stabilize fusion script with a seperate create_attention_node() (#15670)

### Description
<!-- Describe your changes. -->

previously it used create_attention_node() from base class in
fusion_attention.py. sometimes the changes in that file may silently
lead to generating a bad model.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-04-25 13:07:58 -07:00 committed by GitHub
parent 5885abfb35
commit d05777ddb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Dict, Union
from typing import Dict, Optional, Union
import numpy as np
from fusion_attention import AttentionMask, FusionAttention
@ -39,6 +39,114 @@ class FusionT5Attention(FusionAttention):
)
self.static_kv = 1
def create_attention_node(
self,
mask_index: str,
q_matmul: NodeProto,
k_matmul: NodeProto,
v_matmul: NodeProto,
num_heads: int,
hidden_size: int,
input: str,
output: str,
add_qk_str: str,
scale: Optional[float] = None,
) -> Union[NodeProto, None]:
"""Create an Attention node.
Args:
mask_index (str): mask input
q_matmul (NodeProto): MatMul node in fully connection for Q
k_matmul (NodeProto): MatMul node in fully connection for K
v_matmul (NodeProto): MatMul node in fully connection for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
input (str): input name
output (str): output name
Returns:
Union[NodeProto, None]: the node created or None if failed.
"""
assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None
q_weight = self.model.get_initializer(q_matmul.input[1])
k_weight = self.model.get_initializer(k_matmul.input[1])
v_weight = self.model.get_initializer(v_matmul.input[1])
if q_weight is None:
print(
f"{q_matmul.input[1]} is not an initializer. "
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
)
return None
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
vw = NumpyHelper.to_array(v_weight)
# assert q and k have same shape as expected
assert qw.shape == kw.shape
qw_in_size = qw.shape[0]
kw_in_size = kw.shape[0]
vw_in_size = vw.shape[0]
assert qw_in_size == kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
logger.warning(
f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
"Please provide a correct input hidden size or pass in 0"
)
qw_out_size = np.prod(qw.shape[1:])
qkv_weight = np.stack((qw, kw, vw), axis=1)
qkv_weight_dim = 3 * qw_out_size
attention_node_name = self.model.create_node_name("Attention")
weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.flatten().tolist(),
)
self.model.add_initializer(weight, self.this_graph_name)
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
"",
]
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
if add_qk_str is not None:
attention_inputs.append("") # no past
attention_inputs.append(add_qk_str)
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
if scale is not None:
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
return attention_node
def create_mha_node(
self,
query: str,
@ -210,9 +318,6 @@ class FusionT5Attention(FusionAttention):
matmul_q,
matmul_k,
matmul_v,
None,
None,
None,
q_num_heads,
q_hidden_size,
input_shape_node.input[0],