mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Add --use_multi_head_attention in transformers fusion (#14198)
Add an option --use_multi_head_attention to fuse model with MultiHeadAttention operator instead of Attention operator for testing purpose. Note that MultiHeadAttention can be used in self-attention and cross-attention, while Attention operator is used for self-attention only. In Attention operator, there is packed Q/K/V weights for input projection, but that MatMul of input projection is excluded from MultiHeadAttention.
This commit is contained in:
parent
83ad562826
commit
012b34dc4e
6 changed files with 180 additions and 92 deletions
|
|
@ -93,11 +93,14 @@ class FusionAttention(Fusion):
|
|||
hidden_size: int,
|
||||
num_heads: int,
|
||||
attention_mask: AttentionMask,
|
||||
use_multi_head_attention: bool = False,
|
||||
):
|
||||
super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"])
|
||||
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
|
||||
super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.attention_mask = attention_mask
|
||||
self.use_multi_head_attention = use_multi_head_attention
|
||||
|
||||
# Flags to show warning only once
|
||||
self.num_heads_warning = True
|
||||
|
|
@ -108,18 +111,18 @@ class FusionAttention(Fusion):
|
|||
Detect num_heads and hidden_size from Concat node in the following subgraph:
|
||||
|
||||
SkipLayerNormalization or EmbedLayerNormalization
|
||||
/ \
|
||||
MatMul Shape
|
||||
| |
|
||||
Add Gather(indices=0)
|
||||
\ |
|
||||
\ Unsqueeze
|
||||
\ |
|
||||
\ Concat (*, -1, 12, 64)
|
||||
\ /
|
||||
Reshape
|
||||
|
|
||||
Transpose
|
||||
/ |
|
||||
MatMul Shape
|
||||
| |
|
||||
Add Gather(indices=0)
|
||||
| |
|
||||
| Unsqueeze
|
||||
| |
|
||||
| Concat (*, -1, 12, 64)
|
||||
| /
|
||||
Reshape
|
||||
|
|
||||
Transpose
|
||||
"""
|
||||
if len(concat.input) == 4:
|
||||
num_heads = self.model.get_constant_value(concat.input[2])
|
||||
|
|
@ -307,17 +310,18 @@ class FusionAttention(Fusion):
|
|||
|
||||
attention_node_name = self.model.create_node_name("Attention")
|
||||
|
||||
weight = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_weight",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qw_in_size, qkv_weight_dim],
|
||||
vals=qkv_weight.flatten().tolist(),
|
||||
)
|
||||
if not self.use_multi_head_attention:
|
||||
weight = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_weight",
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[qw_in_size, qkv_weight_dim],
|
||||
vals=qkv_weight.flatten().tolist(),
|
||||
)
|
||||
|
||||
# Sometimes weights and bias are stored in fp16
|
||||
if q_weight.data_type == 10:
|
||||
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
|
||||
self.model.add_initializer(weight, self.this_graph_name)
|
||||
# Sometimes weights and bias are stored in fp16
|
||||
if q_weight.data_type == 10:
|
||||
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
|
||||
self.model.add_initializer(weight, self.this_graph_name)
|
||||
|
||||
bias = helper.make_tensor(
|
||||
name=attention_node_name + "_qkv_bias",
|
||||
|
|
@ -329,26 +333,48 @@ class FusionAttention(Fusion):
|
|||
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
|
||||
self.model.add_initializer(bias, self.this_graph_name)
|
||||
|
||||
attention_inputs = [
|
||||
input,
|
||||
attention_node_name + "_qkv_weight",
|
||||
attention_node_name + "_qkv_bias",
|
||||
]
|
||||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
|
||||
if self.use_multi_head_attention:
|
||||
if add_qk_str is not None:
|
||||
logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.")
|
||||
return None
|
||||
|
||||
attention_inputs = [
|
||||
q_matmul.output[0],
|
||||
k_matmul.output[0],
|
||||
v_matmul.output[0],
|
||||
attention_node_name + "_qkv_bias",
|
||||
]
|
||||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
|
||||
attention_node = helper.make_node(
|
||||
"MultiHeadAttention",
|
||||
inputs=attention_inputs,
|
||||
outputs=[output],
|
||||
name=attention_node_name,
|
||||
)
|
||||
else:
|
||||
attention_inputs.append("")
|
||||
attention_inputs = [
|
||||
input,
|
||||
attention_node_name + "_qkv_weight",
|
||||
attention_node_name + "_qkv_bias",
|
||||
]
|
||||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
else:
|
||||
attention_inputs.append("")
|
||||
|
||||
if add_qk_str is not None:
|
||||
attention_inputs.append("")
|
||||
attention_inputs.append(add_qk_str)
|
||||
if add_qk_str is not None:
|
||||
attention_inputs.append("") # no past
|
||||
attention_inputs.append(add_qk_str)
|
||||
|
||||
attention_node = helper.make_node(
|
||||
"Attention",
|
||||
inputs=attention_inputs,
|
||||
outputs=[output],
|
||||
name=attention_node_name,
|
||||
)
|
||||
attention_node = helper.make_node(
|
||||
"Attention",
|
||||
inputs=attention_inputs,
|
||||
outputs=[output],
|
||||
name=attention_node_name,
|
||||
)
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
||||
|
||||
|
|
@ -595,10 +621,11 @@ class FusionAttention(Fusion):
|
|||
|
||||
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
|
||||
self.nodes_to_remove.extend(qk_nodes)
|
||||
self.nodes_to_remove.extend(q_nodes)
|
||||
self.nodes_to_remove.extend(k_nodes)
|
||||
self.nodes_to_remove.extend(v_nodes)
|
||||
|
||||
# For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
|
||||
self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
|
||||
self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
|
||||
self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
|
||||
|
||||
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
|
||||
# self.nodes_to_remove.extend(mask_nodes)
|
||||
self.prune_graph = True
|
||||
|
|
|
|||
|
|
@ -63,48 +63,64 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
self.attention = self.model.find_first_child_by_type(
|
||||
layernorm, "Attention", input_name_to_nodes, recursive=False
|
||||
)
|
||||
if self.attention is None:
|
||||
# In case user disables attention fusion, check whether subgraph looks like Attention.
|
||||
if layernorm.output[0] not in input_name_to_nodes:
|
||||
|
||||
if self.attention is not None:
|
||||
return True
|
||||
|
||||
if layernorm.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[layernorm.output[0]]
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
|
||||
# Try find MultiHeadAttention
|
||||
if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
|
||||
for node in children:
|
||||
if node.op_type == "SkipLayerNormalization":
|
||||
path1 = self.model.match_parent_path(
|
||||
node,
|
||||
["Add", "MatMul", "MultiHeadAttention", "MatMul"],
|
||||
[None, None, 0, 0],
|
||||
)
|
||||
if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
|
||||
self.cross_attention = path1[2]
|
||||
return True
|
||||
|
||||
# In case user disables attention fusion, check whether subgraph looks like Attention.
|
||||
# For Albert, there is MatMul+Add after embedding layer before attention.
|
||||
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
|
||||
grandchildren = input_name_to_nodes[children[0].output[0]]
|
||||
if (
|
||||
len(grandchildren) == 1
|
||||
and grandchildren[0].op_type == "Add"
|
||||
and grandchildren[0].output[0] in input_name_to_nodes
|
||||
):
|
||||
nodes = input_name_to_nodes[grandchildren[0].output[0]]
|
||||
for node in nodes:
|
||||
if node.op_type == "Attention":
|
||||
self.attention = node
|
||||
return True
|
||||
children_types = sorted([child.op_type for child in nodes])
|
||||
|
||||
# Two Shape nodes might be merged by ORT
|
||||
if is_distil_bert:
|
||||
# SkipLayerNormailization might exist when model has been optimized by ORT first.
|
||||
if (
|
||||
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
|
||||
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
|
||||
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
|
||||
):
|
||||
logger.debug("No Attention like subgraph in children of LayerNormalization")
|
||||
return False
|
||||
else:
|
||||
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
|
||||
"MatMul",
|
||||
"MatMul",
|
||||
"MatMul",
|
||||
"SkipLayerNormalization",
|
||||
]:
|
||||
logger.debug("No Attention like subgraph in children of LayerNormalization")
|
||||
return False
|
||||
children = input_name_to_nodes[layernorm.output[0]]
|
||||
|
||||
# For Albert, there is MatMul+Add after embedding layer before attention.
|
||||
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
|
||||
grandchildren = input_name_to_nodes[children[0].output[0]]
|
||||
if (
|
||||
len(grandchildren) == 1
|
||||
and grandchildren[0].op_type == "Add"
|
||||
and grandchildren[0].output[0] in input_name_to_nodes
|
||||
):
|
||||
nodes = input_name_to_nodes[grandchildren[0].output[0]]
|
||||
for node in nodes:
|
||||
if node.op_type == "Attention":
|
||||
self.attention = node
|
||||
return True
|
||||
children_types = sorted([child.op_type for child in nodes])
|
||||
else:
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
|
||||
# Two Shape nodes might be merged by ORT
|
||||
if is_distil_bert:
|
||||
# SkipLayerNormailization might exist when model has been optimized by ORT first.
|
||||
if (
|
||||
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
|
||||
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
|
||||
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
|
||||
):
|
||||
logger.debug("No Attention like subgraph in children of LayerNormalization")
|
||||
return False
|
||||
else:
|
||||
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
|
||||
"MatMul",
|
||||
"MatMul",
|
||||
"MatMul",
|
||||
"SkipLayerNormalization",
|
||||
]:
|
||||
logger.debug("No Attention like subgraph in children of LayerNormalization")
|
||||
return False
|
||||
return True
|
||||
|
||||
def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
|
||||
|
|
@ -713,11 +729,15 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
|
|||
|
||||
for attention_node in attention_nodes:
|
||||
logger.debug("update mask_index in %s", attention_node.name)
|
||||
attention_node.input[3] = embed_node.output[1]
|
||||
if attention_node.op_type == "Attention":
|
||||
attention_node.input[3] = embed_node.output[1]
|
||||
elif attention_node.op_type == "MultiHeadAttention":
|
||||
attention_node.input[4] = embed_node.output[1]
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
# Reset attention and embed_node so that we know fusion is successful when they are not None.
|
||||
self.attention = None
|
||||
self.cross_attention = None
|
||||
self.embed_node = None
|
||||
super().fuse(node, input_name_to_nodes, output_name_to_node)
|
||||
|
||||
|
|
@ -729,15 +749,19 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
|
|||
self.increase_counter("EmbedLayerNormalization(no mask)")
|
||||
return
|
||||
|
||||
if self.attention is None:
|
||||
if self.attention is None and self.cross_attention is None:
|
||||
logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
|
||||
self.increase_counter("EmbedLayerNormalization(no mask)")
|
||||
return
|
||||
|
||||
mask_int32 = self.attention.input[3]
|
||||
if self.attention:
|
||||
mask_int32 = self.attention.input[3]
|
||||
else:
|
||||
mask_int32 = self.cross_attention.input[4]
|
||||
|
||||
children_nodes = input_name_to_nodes[mask_int32]
|
||||
if self.model.find_graph_input(mask_int32):
|
||||
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
|
||||
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
|
||||
self.replace_mask(mask_int32, attention_nodes)
|
||||
self.increase_counter("EmbedLayerNormalization(with mask)")
|
||||
return
|
||||
|
|
@ -749,7 +773,7 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
|
|||
|
||||
node = output_name_to_node[mask_int32]
|
||||
if node.op_type in ["ReduceSum", "Cast"]:
|
||||
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
|
||||
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
|
||||
if node.op_type == "ReduceSum":
|
||||
mask_int32 = node.input[0]
|
||||
if len(children_nodes) == len(attention_nodes):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,14 @@ class FusionOptions:
|
|||
self.enable_gelu = True
|
||||
self.enable_layer_norm = True
|
||||
self.enable_attention = True
|
||||
|
||||
# Use MultiHeadAttention instead of Attention operator. The difference:
|
||||
# (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
|
||||
# merged into one.
|
||||
# (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
|
||||
# (3) MultiHeadAttention has only cuda implementation right now.
|
||||
self.use_multi_head_attention = False
|
||||
|
||||
self.enable_skip_layer_norm = True
|
||||
self.enable_embed_layer_norm = True
|
||||
self.enable_bias_skip_layer_norm = True
|
||||
|
|
@ -48,6 +56,8 @@ class FusionOptions:
|
|||
options.enable_layer_norm = False
|
||||
if args.disable_attention:
|
||||
options.enable_attention = False
|
||||
if args.use_multi_head_attention:
|
||||
options.use_multi_head_attention = True
|
||||
if args.disable_skip_layer_norm:
|
||||
options.enable_skip_layer_norm = False
|
||||
if args.disable_embed_layer_norm:
|
||||
|
|
@ -165,3 +175,13 @@ class FusionOptions:
|
|||
help="no attention mask. Only works for model_type=bert",
|
||||
)
|
||||
parser.set_defaults(no_attention_mask=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_multi_head_attention",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use MultiHeadAttention instead of Attention operator for testing purpose. "
|
||||
"Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. "
|
||||
"MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.",
|
||||
)
|
||||
parser.set_defaults(use_multi_head_attention=False)
|
||||
|
|
|
|||
|
|
@ -385,9 +385,14 @@ class BertOnnxModel(OnnxModel):
|
|||
if (options is None) or options.enable_skip_layer_norm:
|
||||
self.fuse_skip_layer_norm()
|
||||
|
||||
if options is not None:
|
||||
self.attention_mask.set_mask_format(options.attention_mask_format)
|
||||
if options.use_multi_head_attention:
|
||||
self.attention_fusion = FusionAttention(
|
||||
self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
|
||||
)
|
||||
|
||||
if (options is None) or options.enable_attention:
|
||||
if options is not None:
|
||||
self.attention_mask.set_mask_format(options.attention_mask_format)
|
||||
self.fuse_attention()
|
||||
|
||||
# Perform the MatMul fusion after the Attention fusion as we do not
|
||||
|
|
@ -438,6 +443,7 @@ class BertOnnxModel(OnnxModel):
|
|||
ops = [
|
||||
"EmbedLayerNormalization",
|
||||
"Attention",
|
||||
"MultiHeadAttention",
|
||||
"Gelu",
|
||||
"FastGelu",
|
||||
"BiasGelu",
|
||||
|
|
@ -459,7 +465,7 @@ class BertOnnxModel(OnnxModel):
|
|||
"""
|
||||
op_count = self.get_fused_operator_statistics()
|
||||
embed = op_count["EmbedLayerNormalization"]
|
||||
attention = op_count["Attention"] + op_count["QOrderedAttention"]
|
||||
attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
|
||||
gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
|
||||
layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
|
||||
is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,17 @@ class TestFusion(unittest.TestCase):
|
|||
|
||||
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
|
||||
|
||||
def test_multi_head_attention_fusion(self):
|
||||
model = create_bert_attention()
|
||||
dir = "."
|
||||
model_path = os.path.join(dir, "attention.onnx")
|
||||
onnx.save(model, model_path)
|
||||
options = FusionOptions("bert")
|
||||
options.use_multi_head_attention = True
|
||||
optimized_model = optimize_model(model_path, optimization_options=options)
|
||||
os.remove(model_path)
|
||||
self.verify_fusion(optimized_model, "attention_mha.onnx")
|
||||
|
||||
def test_attention_fusion(self):
|
||||
model = create_bert_attention()
|
||||
dir = "."
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue