Optimize flaubert (#5651)

* optimize flaubert

* fix an issue and format

* revert non-relevant change

* review comments
This commit is contained in:
Ye Wang 2020-11-03 09:51:42 -08:00 committed by GitHub
parent 9b010963b7
commit a028ca41ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 4 deletions

View file

@ -85,7 +85,7 @@ class FusionAttention(Fusion):
Fuse Attention subgraph into one Attention node.
"""
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask):
super().__init__(model, "Attention", "SkipLayerNormalization")
super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"])
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_mask = attention_mask
@ -154,15 +154,25 @@ class FusionAttention(Fusion):
return attention_node
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Sometimes SkipLayerNormalization cannot be fused because the Add before LayerNormalization has an output that is used by nodes outside the SkipLayerNormalization pattern.
# Conceptually, we treat the Add before LayerNormalization as the SkipLayerNormalization node, since the two share the same pattern.
start_node = normalize_node
if normalize_node.op_type == 'LayerNormalization':
add_before_layernorm = self.model.match_parent(normalize_node, 'Add', 0)
if add_before_layernorm is not None:
start_node = add_before_layernorm
else:
return
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'],
qkv_nodes = self.model.match_parent_path(start_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'],
[None, 0, 0, 0, 0])
einsum_node = None
if qkv_nodes is not None:
(_, matmul_qkv, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
else:
# Match Albert
qkv_nodes = self.model.match_parent_path(normalize_node, ['Add', 'Einsum', 'Transpose', 'MatMul'],
qkv_nodes = self.model.match_parent_path(start_node, ['Add', 'Einsum', 'Transpose', 'MatMul'],
[1, 0, 0, 0])
if qkv_nodes is not None:
(_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
@ -170,7 +180,7 @@ class FusionAttention(Fusion):
return
other_inputs = []
for i, input in enumerate(normalize_node.input):
for i, input in enumerate(start_node.input):
if input not in output_name_to_node:
continue
@ -181,6 +191,26 @@ class FusionAttention(Fusion):
return
root_input = other_inputs[0]
"""
Match flaubert Mask
|
Mul --> LayerNormalization --> Attention --> MatMul --> Add
| |
| |
+---------------------------------------------------------
"""
mul_before_layernorm = self.model.match_parent(start_node, 'Mul', 0)
if mul_before_layernorm is not None:
mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
if mul_children is not None and len(mul_children) == 2:
layernorm_node = mul_children[1]
if layernorm_node.op_type == 'LayerNormalization':
root_input = layernorm_node.output[0]
else:
return
else:
return
children = input_name_to_nodes[root_input]
children_types = [child.op_type for child in children]
if children_types.count('MatMul') != 3:

View file

@ -107,4 +107,14 @@ class FusionBiasSkipLayerNormalization(Fusion):
name=self.model.create_node_name("SkipLayerNormalization",
"SkipLayerNorm_AddBias_"))
new_node.domain = "com.microsoft"
# Pass attribute "epsilon" from skiplayernorm node to skiplayernorm(add bias)
for att in node.attribute:
if att.name == 'epsilon':
new_node.attribute.extend([att])
# Set default epsilon if no epsilon exists from skiplayernorm
if len(new_node.attribute) == 0:
new_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])
self.nodes_to_add.append(new_node)

View file

@ -278,6 +278,8 @@ def load_pt_model_from_tf(model_name):
config, model = tf2pt_pipeline(model_name)
return config, model
def validate_and_optimize_onnx(model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu,
precision, optimize_onnx, validate_onnx, use_raw_attention_mask, overwrite, config,
model_fusion_statistics, onnx_model_path, example_inputs, example_outputs_flatten):

View file

@ -335,6 +335,8 @@ class TestBertOptimization(unittest.TestCase):
# output not close issue
self.test_optimizer_on_huggingface_model("flaubert/flaubert_base_cased", [0, 12, 0, 0, 12, 0, 25],
validate_model=False)
self.test_optimizer_on_huggingface_model("flaubert/flaubert_small_cased", [0, 6, 0, 0, 6, 12, 1],
validate_model=False)
def test_huggingface_dialogpt_fusion(self):
self.test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])