mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
stabilize fusion script with a separate create_attention_node() (#15670)
### Description <!-- Describe your changes. --> Previously the T5 fusion script used create_attention_node() from the base class in fusion_attention.py. Changes made to that file could sometimes silently lead to generating a bad model, so the script now has its own create_attention_node(). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
5885abfb35
commit
d05777ddb6
1 changed files with 109 additions and 4 deletions
|
|
@ -3,7 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from fusion_attention import AttentionMask, FusionAttention
|
||||
|
|
@ -39,6 +39,114 @@ class FusionT5Attention(FusionAttention):
|
|||
)
|
||||
self.static_kv = 1
|
||||
|
||||
def create_attention_node(
    self,
    mask_index: Optional[str],
    q_matmul: NodeProto,
    k_matmul: NodeProto,
    v_matmul: NodeProto,
    num_heads: int,
    hidden_size: int,
    input: str,
    output: str,
    add_qk_str: Optional[str],
    scale: Optional[float] = None,
) -> Union[NodeProto, None]:
    """Create an Attention node.

    The Q, K and V projection weights are packed into one new initializer that
    is registered on the model; the returned node itself is not added to the
    graph by this method.

    Args:
        mask_index (Optional[str]): mask input. When None, an empty input name is used instead.
        q_matmul (NodeProto): MatMul node in fully connection for Q
        k_matmul (NodeProto): MatMul node in fully connection for K
        v_matmul (NodeProto): MatMul node in fully connection for V
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
        hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            Pass 0 to skip the consistency checks against the weight shape.
        input (str): input name
        output (str): output name
        add_qk_str (Optional[str]): name of the tensor passed as the sixth input of the node.
            When None, neither the past placeholder nor this input is appended.
        scale (Optional[float]): value of the "scale" attribute. The attribute is omitted when None.

    Returns:
        Union[NodeProto, None]: the node created or None if failed.
    """
    assert num_heads > 0

    if hidden_size > 0 and (hidden_size % num_heads) != 0:
        logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
        return None

    matmuls = (q_matmul, k_matmul, v_matmul)
    initializers = [self.model.get_initializer(matmul.input[1]) for matmul in matmuls]

    # Every projection weight must be an initializer. Checking only Q would let a
    # missing K or V weight reach NumpyHelper.to_array as None and fail there,
    # instead of returning None as documented.
    for matmul, initializer in zip(matmuls, initializers):
        if initializer is None:
            print(
                f"{matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None

    qw, kw, vw = (NumpyHelper.to_array(initializer) for initializer in initializers)

    # assert q and k have same shape as expected
    assert qw.shape == kw.shape

    qw_in_size = qw.shape[0]
    kw_in_size = kw.shape[0]
    vw_in_size = vw.shape[0]

    assert qw_in_size == kw_in_size == vw_in_size

    if hidden_size > 0 and hidden_size != qw_in_size:
        # Only a warning: fusion proceeds using the size taken from the weights.
        logger.warning(
            f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
            "Please provide a correct input hidden size or pass in 0"
        )

    # np.prod returns a NumPy scalar; convert to a plain int before using it as a tensor dim.
    qw_out_size = int(np.prod(qw.shape[1:]))

    # Stacking on axis 1 gives shape (in, 3, out), so flattening lays out each
    # input row as [q_row, k_row, v_row], matching dims [in, 3 * out] below.
    qkv_weight = np.stack((qw, kw, vw), axis=1)
    qkv_weight_dim = 3 * qw_out_size

    attention_node_name = self.model.create_node_name("Attention")
    qkv_weight_name = attention_node_name + "_qkv_weight"

    weight = helper.make_tensor(
        name=qkv_weight_name,
        data_type=TensorProto.FLOAT,
        dims=[qw_in_size, qkv_weight_dim],
        vals=qkv_weight.flatten().tolist(),
    )

    self.model.add_initializer(weight, self.this_graph_name)

    attention_inputs = [
        input,
        qkv_weight_name,
        "",  # third input left empty (presumably the bias slot; confirm against the Attention op schema)
        mask_index if mask_index is not None else "",
    ]

    if add_qk_str is not None:
        attention_inputs.append("")  # no past
        attention_inputs.append(add_qk_str)

    attention_node = helper.make_node(
        "Attention",
        inputs=attention_inputs,
        outputs=[output],
        name=attention_node_name,
    )
    attention_node.domain = "com.microsoft"
    attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

    if scale is not None:
        attention_node.attribute.extend([helper.make_attribute("scale", scale)])

    if self.mask_filter_value is not None:
        attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

    return attention_node
|
||||
|
||||
def create_mha_node(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -210,9 +318,6 @@ class FusionT5Attention(FusionAttention):
|
|||
matmul_q,
|
||||
matmul_k,
|
||||
matmul_v,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
q_num_heads,
|
||||
q_hidden_size,
|
||||
input_shape_node.input[0],
|
||||
|
|
|
|||
Loading…
Reference in a new issue