Add --use_multi_head_attention in transformers fusion (#14198)

Add an option --use_multi_head_attention to fuse model with
MultiHeadAttention operator instead of Attention operator for testing
purpose.

Note that MultiHeadAttention can be used in self-attention and
cross-attention, while Attention operator is used for self-attention
only. In Attention operator, there is packed Q/K/V weights for input
projection, but that MatMul of input projection is excluded from
MultiHeadAttention.
This commit is contained in:
Tianlei Wu 2023-01-11 13:20:05 -08:00 committed by GitHub
parent 83ad562826
commit 012b34dc4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 180 additions and 92 deletions

View file

@ -93,11 +93,14 @@ class FusionAttention(Fusion):
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
use_multi_head_attention: bool = False,
):
super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"])
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_mask = attention_mask
self.use_multi_head_attention = use_multi_head_attention
# Flags to show warning only once
self.num_heads_warning = True
@ -108,18 +111,18 @@ class FusionAttention(Fusion):
Detect num_heads and hidden_size from Concat node in the following subgraph:
SkipLayerNormalization or EmbedLayerNormalization
/ \
MatMul Shape
| |
Add Gather(indices=0)
\ |
\ Unsqueeze
\ |
\ Concat (*, -1, 12, 64)
\ /
Reshape
|
Transpose
/ |
MatMul Shape
| |
Add Gather(indices=0)
| |
| Unsqueeze
| |
| Concat (*, -1, 12, 64)
| /
Reshape
|
Transpose
"""
if len(concat.input) == 4:
num_heads = self.model.get_constant_value(concat.input[2])
@ -307,17 +310,18 @@ class FusionAttention(Fusion):
attention_node_name = self.model.create_node_name("Attention")
weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.flatten().tolist(),
)
if not self.use_multi_head_attention:
weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.flatten().tolist(),
)
# Sometimes weights and bias are stored in fp16
if q_weight.data_type == 10:
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)
# Sometimes weights and bias are stored in fp16
if q_weight.data_type == 10:
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)
bias = helper.make_tensor(
name=attention_node_name + "_qkv_bias",
@ -329,26 +333,48 @@ class FusionAttention(Fusion):
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
self.model.add_initializer(bias, self.this_graph_name)
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
if add_qk_str is not None:
logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.")
return None
attention_inputs = [
q_matmul.output[0],
k_matmul.output[0],
v_matmul.output[0],
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
attention_node = helper.make_node(
"MultiHeadAttention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
else:
attention_inputs.append("")
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
if add_qk_str is not None:
attention_inputs.append("")
attention_inputs.append(add_qk_str)
if add_qk_str is not None:
attention_inputs.append("") # no past
attention_inputs.append(add_qk_str)
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
@ -595,10 +621,11 @@ class FusionAttention(Fusion):
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
self.nodes_to_remove.extend(q_nodes)
self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)
# For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
# self.nodes_to_remove.extend(mask_nodes)
self.prune_graph = True

View file

@ -63,48 +63,64 @@ class FusionEmbedLayerNoMask(Fusion):
self.attention = self.model.find_first_child_by_type(
layernorm, "Attention", input_name_to_nodes, recursive=False
)
if self.attention is None:
# In case user disables attention fusion, check whether subgraph looks like Attention.
if layernorm.output[0] not in input_name_to_nodes:
if self.attention is not None:
return True
if layernorm.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[layernorm.output[0]]
children_types = sorted([child.op_type for child in children])
# Try find MultiHeadAttention
if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
for node in children:
if node.op_type == "SkipLayerNormalization":
path1 = self.model.match_parent_path(
node,
["Add", "MatMul", "MultiHeadAttention", "MatMul"],
[None, None, 0, 0],
)
if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
self.cross_attention = path1[2]
return True
# In case user disables attention fusion, check whether subgraph looks like Attention.
# For Albert, there is MatMul+Add after embedding layer before attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])
# Two Shape nodes might be merged by ORT
if is_distil_bert:
# SkipLayerNormailization might exist when model has been optimized by ORT first.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
children = input_name_to_nodes[layernorm.output[0]]
# For Albert, there is MatMul+Add after embedding layer before attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])
else:
children_types = sorted([child.op_type for child in children])
# Two Shape nodes might be merged by ORT
if is_distil_bert:
# SkipLayerNormailization might exist when model has been optimized by ORT first.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
return True
def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
@ -713,11 +729,15 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
for attention_node in attention_nodes:
logger.debug("update mask_index in %s", attention_node.name)
attention_node.input[3] = embed_node.output[1]
if attention_node.op_type == "Attention":
attention_node.input[3] = embed_node.output[1]
elif attention_node.op_type == "MultiHeadAttention":
attention_node.input[4] = embed_node.output[1]
def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Reset attention and embed_node so that we know fusion is successful when they are not None.
self.attention = None
self.cross_attention = None
self.embed_node = None
super().fuse(node, input_name_to_nodes, output_name_to_node)
@ -729,15 +749,19 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
self.increase_counter("EmbedLayerNormalization(no mask)")
return
if self.attention is None:
if self.attention is None and self.cross_attention is None:
logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
self.increase_counter("EmbedLayerNormalization(no mask)")
return
mask_int32 = self.attention.input[3]
if self.attention:
mask_int32 = self.attention.input[3]
else:
mask_int32 = self.cross_attention.input[4]
children_nodes = input_name_to_nodes[mask_int32]
if self.model.find_graph_input(mask_int32):
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")
return
@ -749,7 +773,7 @@ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
node = output_name_to_node[mask_int32]
if node.op_type in ["ReduceSum", "Cast"]:
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
if node.op_type == "ReduceSum":
mask_int32 = node.input[0]
if len(children_nodes) == len(attention_nodes):

View file

@ -19,6 +19,14 @@ class FusionOptions:
self.enable_gelu = True
self.enable_layer_norm = True
self.enable_attention = True
# Use MultiHeadAttention instead of Attention operator. The difference:
# (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
# merged into one.
# (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
# (3) MultiHeadAttention has only cuda implementation right now.
self.use_multi_head_attention = False
self.enable_skip_layer_norm = True
self.enable_embed_layer_norm = True
self.enable_bias_skip_layer_norm = True
@ -48,6 +56,8 @@ class FusionOptions:
options.enable_layer_norm = False
if args.disable_attention:
options.enable_attention = False
if args.use_multi_head_attention:
options.use_multi_head_attention = True
if args.disable_skip_layer_norm:
options.enable_skip_layer_norm = False
if args.disable_embed_layer_norm:
@ -165,3 +175,13 @@ class FusionOptions:
help="no attention mask. Only works for model_type=bert",
)
parser.set_defaults(no_attention_mask=False)
parser.add_argument(
"--use_multi_head_attention",
required=False,
action="store_true",
help="Use MultiHeadAttention instead of Attention operator for testing purpose. "
"Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. "
"MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.",
)
parser.set_defaults(use_multi_head_attention=False)

View file

@ -385,9 +385,14 @@ class BertOnnxModel(OnnxModel):
if (options is None) or options.enable_skip_layer_norm:
self.fuse_skip_layer_norm()
if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
if options.use_multi_head_attention:
self.attention_fusion = FusionAttention(
self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
)
if (options is None) or options.enable_attention:
if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
self.fuse_attention()
# Perform the MatMul fusion after the Attention fusion as we do not
@ -438,6 +443,7 @@ class BertOnnxModel(OnnxModel):
ops = [
"EmbedLayerNormalization",
"Attention",
"MultiHeadAttention",
"Gelu",
"FastGelu",
"BiasGelu",
@ -459,7 +465,7 @@ class BertOnnxModel(OnnxModel):
"""
op_count = self.get_fused_operator_statistics()
embed = op_count["EmbedLayerNormalization"]
attention = op_count["Attention"] + op_count["QOrderedAttention"]
attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

View file

@ -33,6 +33,17 @@ class TestFusion(unittest.TestCase):
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
def test_multi_head_attention_fusion(self):
model = create_bert_attention()
dir = "."
model_path = os.path.join(dir, "attention.onnx")
onnx.save(model, model_path)
options = FusionOptions("bert")
options.use_multi_head_attention = True
optimized_model = optimize_model(model_path, optimization_options=options)
os.remove(model_path)
self.verify_fusion(optimized_model, "attention_mha.onnx")
def test_attention_fusion(self):
model = create_bert_attention()
dir = "."