mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
optimize a bert model converted using tf2onnx (#5492)
* optimize a bert model converted using tf2onnx * add test data * update * remove comments * format * Revert "format" This reverts commit f8ae88cb564bce5caf4780e56561403f3ba3d524. * Revert "remove comments" This reverts commit 59d8a693581a731fd0291b70fe2c9cec6c4950fe. * add a squeeze node to convert a 3-d mask to 2-d * update * update * verify and add comments
This commit is contained in:
parent
3323fb6082
commit
5f516899bf
5 changed files with 190 additions and 27 deletions
|
|
@ -131,15 +131,18 @@ class FusionAttention(Fusion):
|
|||
weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[self.hidden_size, 3 * self.hidden_size],
|
||||
vals=bytes(qkv_weight.flatten()),
|
||||
raw=True)
|
||||
vals=qkv_weight.flatten().tolist())
|
||||
# Sometimes weights and bias are stored in fp16
|
||||
if q_weight.data_type == 10:
|
||||
weight.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(weight).astype(np.float16), weight.name))
|
||||
self.model.add_initializer(weight)
|
||||
|
||||
bias = helper.make_tensor(name=attention_node_name + '_qkv_bias',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[3 * self.hidden_size],
|
||||
vals=bytes(qkv_bias.flatten()),
|
||||
raw=True)
|
||||
vals=qkv_bias.flatten().tolist())
|
||||
if q_bias.data_type == 10:
|
||||
bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name))
|
||||
self.model.add_initializer(bias)
|
||||
|
||||
attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias']
|
||||
|
|
|
|||
|
|
@ -121,24 +121,28 @@ class FusionLayerNormalizationTF(Fusion):
|
|||
|
||||
def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
|
||||
"""
|
||||
Layer Norm from Keras in Tensorflow:
|
||||
+----------------------+
|
||||
| |
|
||||
| v (B) (B) (A)
|
||||
Add --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add
|
||||
| | | ^ ^
|
||||
| | | | |
|
||||
| +----------------------------------------------------------------------------|-------+ |
|
||||
| v |
|
||||
+-------------------------------------------------------------------------------------> Mul--------------------+
|
||||
Layer Norm from Tensorflow model(using keras2onnx or tf2onnx):
|
||||
+------------------------------------+
|
||||
| |
|
||||
| |
|
||||
(Cast_1) |
|
||||
| |
|
||||
| v (B) (B) (A)
|
||||
Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add
|
||||
| | | ^ ^
|
||||
| | | | |
|
||||
| +--------------------------------------------------(Cast_2)-------------------------------|-------+ |
|
||||
| v |
|
||||
+---------------------------------------------------------------------------------------------------------------> Mul--------------------+
|
||||
"""
|
||||
return_indice = []
|
||||
parent_nodes = self.model.match_parent_path(
|
||||
_, parent_nodes, return_indice = self.model.match_parent_paths(
|
||||
node,
|
||||
['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
|
||||
[ 1, 1, None, 0, 0, 0, None, 0, 0, None],
|
||||
output_name_to_node,
|
||||
return_indice=return_indice) # yapf: disable
|
||||
[(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
|
||||
[ 1, 1, None, 0, 0, 0, None, 0, 0, None]),
|
||||
(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'Cast', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
|
||||
[ 1, 1, None, 0, 0, 0, 0, None, 0, 0, None])],
|
||||
output_name_to_node) # yapf: disable
|
||||
|
||||
if parent_nodes is None:
|
||||
return
|
||||
|
|
@ -148,24 +152,35 @@ class FusionLayerNormalizationTF(Fusion):
|
|||
logger.debug("return indice is exepected in [0, 1], but got {return_indice}")
|
||||
return
|
||||
|
||||
sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes
|
||||
sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0 = parent_nodes[:6]
|
||||
reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]
|
||||
|
||||
cast_node_3 = None
|
||||
if len(parent_nodes) == 11:
|
||||
cast_node_3 = parent_nodes[6]
|
||||
assert(cast_node_3.op_type == 'Cast')
|
||||
|
||||
mul_node_3 = self.model.match_parent(node, 'Mul', 0, output_name_to_node)
|
||||
if mul_node_3 is None:
|
||||
logger.debug("mul_node_3 not found")
|
||||
return
|
||||
|
||||
root_node = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
|
||||
node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
|
||||
root_node = node_before_reduce if cast_node_3 is None else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
|
||||
if root_node is None:
|
||||
logger.debug("root node is none")
|
||||
return
|
||||
|
||||
i, epsilon = self.model.get_constant_input(add_node_0)
|
||||
if epsilon is None or epsilon <= 0 or epsilon > 1.0E-5:
|
||||
if epsilon is None or epsilon <= 0 or (epsilon > 1.0E-5 and cast_node_3 is None):
|
||||
logger.debug("epsilon is not matched")
|
||||
return
|
||||
|
||||
if reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input:
|
||||
if cast_node_3 is None and (reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input):
|
||||
logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
|
||||
return
|
||||
|
||||
if cast_node_3 is not None and (node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input):
|
||||
logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
|
||||
return
|
||||
|
||||
|
|
@ -177,6 +192,14 @@ class FusionLayerNormalizationTF(Fusion):
|
|||
node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0,
|
||||
mul_node_2, sub_node_1, reduce_mean_node_1, mul_node_3
|
||||
]
|
||||
|
||||
if cast_node_3 is not None:
|
||||
cast_node_2 = self.model.match_parent(mul_node_0, 'Cast', 0, output_name_to_node)
|
||||
if cast_node_2 is None:
|
||||
logger.debug("cast_node_2 not found")
|
||||
return
|
||||
subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])
|
||||
|
||||
if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.model.input_name_to_nodes(),
|
||||
self.model.output_name_to_node()):
|
||||
logger.debug("not safe to fuse layer normalization")
|
||||
|
|
@ -189,7 +212,8 @@ class FusionLayerNormalizationTF(Fusion):
|
|||
|
||||
#TODO: add epsilon attribute
|
||||
fused_node = helper.make_node('LayerNormalization',
|
||||
inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
|
||||
inputs=[mul_node_3.input[0], weight_input, bias_input],
|
||||
outputs=[node.output[0]])
|
||||
fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
|
||||
self.nodes_to_add.append(fused_node)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import sys
|
|||
import argparse
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from onnx import ModelProto, TensorProto, numpy_helper, helper
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -295,6 +295,126 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.prune_graph()
|
||||
break
|
||||
|
||||
def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
|
||||
for x in [matmul_q, matmul_k, matmul_v]:
|
||||
root_input = x.input[0]
|
||||
root_node = output_name_to_node[root_input]
|
||||
if root_node == parent:
|
||||
continue
|
||||
logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def fuse_attention(self):
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
attention_count = 0
|
||||
|
||||
skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
|
||||
for normalize_node in skip_layer_norm_nodes:
|
||||
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
|
||||
parent = self.get_parent(normalize_node, 1)
|
||||
if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
|
||||
parent = self.get_parent(normalize_node, 0)
|
||||
if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
|
||||
logger.debug("Failed to match parent of normalize_node")
|
||||
continue
|
||||
|
||||
qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'],
|
||||
[0, 0, 0, 0, 0])
|
||||
if qkv_nodes is None:
|
||||
qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'],
|
||||
[1, 0, 0, 0])
|
||||
if qkv_nodes is None:
|
||||
logger.debug("Failed to match qkv nodes")
|
||||
continue
|
||||
|
||||
(reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:]
|
||||
v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if v_nodes is None:
|
||||
logger.debug("Failed to match v path")
|
||||
continue
|
||||
|
||||
(transpose_v, reshape_v, add_v, matmul_v) = v_nodes
|
||||
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0])
|
||||
if qk_nodes is None:
|
||||
logger.debug("Failed to match qk_paths")
|
||||
continue
|
||||
(softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
|
||||
|
||||
q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
|
||||
if q_nodes is None:
|
||||
logger.debug("Failed to match q path")
|
||||
continue
|
||||
(transpose_q, reshape_q, add_q, matmul_q) = q_nodes
|
||||
|
||||
k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if k_nodes is None:
|
||||
logger.debug("Failed to match k path")
|
||||
continue
|
||||
(transpose_k, reshape_k, add_k, matmul_k) = k_nodes
|
||||
|
||||
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1])
|
||||
if mask_nodes is None:
|
||||
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0])
|
||||
if mask_nodes is None:
|
||||
logger.debug("Failed to match mask path")
|
||||
continue
|
||||
|
||||
if not self.has_constant_input(mask_nodes[1], 1):
|
||||
logger.debug("Sub node expected to have an input with constant value 1.0.")
|
||||
continue
|
||||
|
||||
# add a squeeze node to convert a 3-d mask to 2-d
|
||||
squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0])
|
||||
squeeze_node_name = "Squeeze_3d_to_2d_mask"
|
||||
squeeze_output_name = squeeze_node_name + "_output"
|
||||
if squeeze_node is None and len(mask_nodes) == 5:
|
||||
mask_input = mask_nodes[-1].input[1]
|
||||
self.add_node(
|
||||
helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1]))
|
||||
mask_nodes[-1].input[0] = squeeze_output_name
|
||||
|
||||
is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
|
||||
if is_same_root:
|
||||
mask_index = self.attention_mask.process_mask(squeeze_output_name)
|
||||
logger.debug("Create an Attention node.")
|
||||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
|
||||
add_q, add_k, add_v, parent.output[0],
|
||||
reshape_qkv.output[0])
|
||||
if parent.op_type == 'Reshape':
|
||||
# Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
|
||||
hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
|
||||
tensor = helper.make_tensor(
|
||||
name=parent.name + "_modified",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[3],
|
||||
vals=np.int64([[1, -1, hidden_size]]).tobytes(),
|
||||
raw=True)
|
||||
self.add_initializer(tensor)
|
||||
parent.input[1] = parent.name + "_modified"
|
||||
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
||||
self.add_node(attention_node)
|
||||
attention_count += 1
|
||||
|
||||
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
|
||||
nodes_to_remove.extend(qk_nodes)
|
||||
nodes_to_remove.extend(q_nodes)
|
||||
nodes_to_remove.extend(k_nodes)
|
||||
nodes_to_remove.extend(v_nodes)
|
||||
nodes_to_remove.extend(mask_nodes)
|
||||
else:
|
||||
logger.debug("Root node not matched.")
|
||||
continue
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.update_graph()
|
||||
logger.info(f"Fused Attention count:{attention_count}")
|
||||
|
||||
def preprocess(self):
|
||||
self.remove_identity()
|
||||
self.process_embedding()
|
||||
|
|
@ -315,4 +435,5 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
|
||||
def postprocess(self):
|
||||
self.remove_reshape_before_first_attention()
|
||||
self.prune_graph()
|
||||
# Temporary work around for the following comment as it will cause topological issues for a bert model
|
||||
# self.prune_graph()
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -30,6 +30,7 @@ BERT_TEST_MODELS = {
|
|||
"gpt2_past": ('gpt2_pytorch1.5_opset11', 'gpt2_past.onnx'),
|
||||
"gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'),
|
||||
"multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'),
|
||||
"bert_tf2onnx_0": ('other_models', 'bert_tf2onnx_0.onnx')
|
||||
}
|
||||
|
||||
skip_on_ort_version = pytest.mark.skipif(onnxruntime.__version__ == ('1.3.0'),
|
||||
|
|
@ -297,6 +298,20 @@ class TestBertOptimization(unittest.TestCase):
|
|||
}
|
||||
self.verify_node_count(model, expected_node_count, 'test_multiple_embed')
|
||||
|
||||
def test_bert_tf2onnx_0(self):
|
||||
input = _get_test_model_path('bert_tf2onnx_0')
|
||||
model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8)
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 0,
|
||||
'Attention': 6,
|
||||
'Gelu': 0,
|
||||
'FastGelu': 6,
|
||||
'BiasGelu': 0,
|
||||
'LayerNormalization': 0,
|
||||
'SkipLayerNormalization': 13
|
||||
}
|
||||
self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')
|
||||
|
||||
def test_huggingface_bert_fusion(self):
|
||||
self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)
|
||||
self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2)
|
||||
|
|
@ -343,4 +358,4 @@ class TestBertOptimization(unittest.TestCase):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue