Improve BERT optimization script: Gelu and LayerNorm for models from Tensorflow 2.* (#2957)

* Add unit test.
Add an option --use_onnxruntime to use onnxruntime to do optimization for pytorch model.
Update layer norm and gelu for tensorflow 2.1 keras bert model.
Add logging and use f-strings.
Add extra checking for tensorflow model reshape fusion.
Allow output model to json for test purpose.
update match parent path utility function to return index

* remove function not used.
This commit is contained in:
Tianlei Wu 2020-02-07 11:01:03 -08:00 committed by GitHub
parent 0beb75ce77
commit 62383b0328
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 556 additions and 110 deletions

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import logging
import onnx
import sys
import argparse
@ -11,6 +12,8 @@ from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
from OnnxModel import OnnxModel
logger = logging.getLogger(__name__)
class BertOnnxModel(OnnxModel):
def __init__(self, model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose):
assert num_heads > 0
@ -219,7 +222,7 @@ class BertOnnxModel(OnnxModel):
self.remove_nodes(nodes_to_remove)
self.update_graph()
print("Fused Attention count:", attention_count)
logger.info(f"Fused Attention count:{attention_count}")
def fuse_gelu(self, gelu_op_name):
self.fuse_gelu_with_elf(gelu_op_name)
@ -294,7 +297,7 @@ class BertOnnxModel(OnnxModel):
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
if len(nodes_to_add) > 0:
print("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(nodes_to_add)))
logger.info("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(nodes_to_add)))
"""
Fuse Gelu with tanh into one node:
@ -389,7 +392,7 @@ class BertOnnxModel(OnnxModel):
nodes_to_add.append(gelu_node)
if len(nodes_to_add) > 0:
print("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
logger.info("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -437,7 +440,7 @@ class BertOnnxModel(OnnxModel):
nodes_to_add.append(gelu_node)
if len(nodes_to_add) > 0:
print("Fused FastGelu with Bias count:", len(nodes_to_add))
logger.info(f"Fused FastGelu with Bias count:{len(nodes_to_add)}")
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -484,7 +487,7 @@ class BertOnnxModel(OnnxModel):
nodes_to_add.append(new_node)
if len(nodes_to_add) > 0:
print("Fused SkipLayerNormalization with Bias count:", len(nodes_to_add))
logger.info(f"Fused SkipLayerNormalization with Bias count:{len(nodes_to_add)}")
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -540,7 +543,11 @@ class BertOnnxModel(OnnxModel):
concat_2 = self.get_initializer(concat_node.input[2])
if concat_2 is None:
continue
shape.extend(numpy_helper.to_array(concat_2))
concat_value = numpy_helper.to_array(concat_2)
if isinstance(concat_value, list):
shape.extend(concat_value)
else:
shape.append(concat_value)
if len(concat_node.input) == 4 and self.get_initializer(concat_node.input[3]) is None:
path2 = self.match_parent_path(concat_node, ['Unsqueeze', 'Div', 'Gather', 'Shape'], [3, 0, 0, 0], output_name_to_node)
@ -552,7 +559,12 @@ class BertOnnxModel(OnnxModel):
concat_3 = self.get_initializer(concat_node.input[3])
if concat_3 is None:
continue
shape.extend(numpy_helper.to_array(concat_3))
concat_value = numpy_helper.to_array(concat_3)
if isinstance(concat_value, list):
shape.extend(concat_value)
else:
shape.append(concat_value)
root_input = reshape_node.input[0]
same_shape_input = True
@ -582,7 +594,7 @@ class BertOnnxModel(OnnxModel):
nodes_to_remove.extend(path3)
nodes_to_add.append(new_node)
print("Fused Reshape count:", len(nodes_to_add))
logger.info(f"Fused Reshape count:{len(nodes_to_add)}")
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -615,10 +627,10 @@ class BertOnnxModel(OnnxModel):
output_name_to_node = self.output_name_to_node()
if len(self.mask_indice) == 0:
print("skip embed layer fusion since mask input is not found")
logger.info("skip embed layer fusion since mask input is not found")
return
if len(self.mask_indice) > 1:
print("skip embed layer fusion since there are multiple mask inputs found")
logger.info("skip embed layer fusion since there are multiple mask inputs found")
return
mask_input_name = next(iter(self.mask_indice))
mask_output_name = self.mask_indice[mask_input_name]
@ -635,13 +647,14 @@ class BertOnnxModel(OnnxModel):
break
if normalize_node is None:
print("Failed to find embedding layer")
logger.info("Failed to find embedding layer")
return
# Here we assume the order of embedding is word_embedding +
# position_embedding + segment_embedding.
word_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
if word_embedding_path is None:
print("Failed to find word embedding")
logger.info("Failed to find word embedding")
return
add_node, word_embedding_gather = word_embedding_path
input_ids = word_embedding_gather.input[1]
@ -654,14 +667,14 @@ class BertOnnxModel(OnnxModel):
if position_embedding_path is None:
position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
if position_embedding_path is None:
print("Failed to find position embedding")
logger.info("Failed to find position embedding")
return
position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
else:
position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
if not position_embedding_shape is None and position_embedding_shape.input[0] != input_ids:
print("position and word embedding is expected to be applied on same input")
logger.info("position and word embedding is expected to be applied on same input")
return
else:
_, position_embedding_weight_node = position_embedding_path
@ -670,7 +683,7 @@ class BertOnnxModel(OnnxModel):
if segment_embedding_path is None:
segment_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
if segment_embedding_path is None:
print("Failed to find segment embedding")
logger.info("Failed to find segment embedding")
return
_, segment_embedding_gather = segment_embedding_path
else:
@ -724,7 +737,7 @@ class BertOnnxModel(OnnxModel):
self.remove_nodes(nodes_to_remove)
self.add_node(embed_node)
self.update_graph()
print("Fused EmbedLayerNormalization count: 1")
logger.info("Fused EmbedLayerNormalization count: 1")
# Change graph input data type int32 if needed.
if self.input_int32:
@ -816,17 +829,6 @@ class BertOnnxModel(OnnxModel):
| |
+----------------------+
TODO: Batch Layer Norm from Keras in Tensorflow:
+----------------------+
| |
| v (B) (A)
Add --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add
| | | ^ ^
| | | | |
| +----------------------------------------------------------------------------|-------+ |
| v |
+-------------------------------------------------------------------------------------> Mul--------------------+
"""
def fuse_layer_norm(self):
input_name_to_nodes = self.input_name_to_nodes()
@ -912,8 +914,8 @@ class BertOnnxModel(OnnxModel):
self.remove_nodes(nodes_to_remove)
self.add_nodes(skip_layernorm_nodes)
self.add_nodes(layernorm_nodes)
print("Fused SkipLayerNormalization count:", len(skip_layernorm_nodes))
print("Fused LayerNormalization count:", len(layernorm_nodes))
logger.info(f"Fused SkipLayerNormalization count: {len(skip_layernorm_nodes)}")
logger.info(f"Fused LayerNormalization count: {len(layernorm_nodes)}")
def preprocess(self):
return
@ -945,4 +947,4 @@ class BertOnnxModel(OnnxModel):
# Use symbolic batch dimension in input and output.
self.update_dynamic_batch_io()
print("opset verion", self.model.opset_import[0].version)
logger.info(f"opset verion: {self.model.opset_import[0].version}")

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import logging
import onnx
import sys
import argparse
@ -11,10 +12,91 @@ from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
from BertOnnxModel import BertOnnxModel
logger = logging.getLogger(__name__)
class BertOnnxModelTF(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose):
    """Initialize the TensorFlow-specific BERT optimizer.

    The method was previously spelled ``__init`` (so it was never used as the
    constructor) and forwarded only part of the arguments. It now forwards
    every argument unchanged, in the order of the BertOnnxModel.__init__
    signature.
    """
    super().__init__(model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose)
"""
Fuse Gelu with Erf into one node:
+----------------------------------------------+
| |
| v
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
(A=0.7071067690849304) (B=1) (B=0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
def fuse_gelu_with_elf(self, gelu_op_name):
    """Fuse every Erf-based Gelu subgraph into one ``gelu_op_name`` node.

    Starting from each Erf node, the pattern
    ``Mul(~0.7071) -> Erf -> Add(1) -> Mul(0.5) -> Mul(root)`` is matched.
    When the five nodes can be removed safely they are replaced by a single
    node in the "com.microsoft" domain that reads the root output and writes
    the output of the last Mul.

    Args:
        gelu_op_name (str): op type of the fused node to create.
    """
    consumers_of = self.input_name_to_nodes()
    producer_of = self.output_name_to_node()

    def only_consumer(producer, expected_op_type):
        # Return the single consumer of producer.output[0] when it has the
        # expected op type; None when the output is unused, shared, or the
        # consumer is of a different type.
        tensor_name = producer.output[0]
        if tensor_name not in consumers_of:
            return None
        consumers = consumers_of[tensor_name]
        if len(consumers) == 1 and consumers[0].op_type == expected_op_type:
            return consumers[0]
        return None

    fused_away = []
    fused_nodes = []

    for erf_node in self.get_nodes_by_op_type('Erf'):
        # Downstream of Erf: Add with constant 1, then Mul with constant 0.5.
        add_after_erf = only_consumer(erf_node, 'Add')
        if add_after_erf is None:
            continue
        if not self.has_constant_input(add_after_erf, 1):
            continue

        mul_half = only_consumer(add_after_erf, 'Mul')
        if mul_half is None:
            continue
        if not self.has_constant_input(mul_half, 0.5):
            continue

        # Upstream of Erf: Mul by roughly 1/sqrt(2); its other input is the root.
        first_mul = self.match_parent(erf_node, 'Mul', 0, producer_of)
        if first_mul is None:
            continue

        const_index = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if const_index < 0:
            continue

        root_index = 0 if const_index == 1 else 1
        root_node = self.get_parent(first_mul, root_index, producer_of)
        if root_node is None:
            continue

        # The closing Mul must combine the 0.5-scaled branch with the root output.
        last_mul = only_consumer(mul_half, 'Mul')
        if last_mul is None:
            continue

        root_output = root_node.output[0]
        if last_mul.input[0] != root_output and last_mul.input[1] != root_output:
            continue

        matched = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        if not self.is_safe_to_fuse_nodes(matched, [last_mul.output[0]], consumers_of, producer_of):
            continue

        fused_away.extend(matched)
        fused = onnx.helper.make_node(gelu_op_name,
                                      inputs=[root_output],
                                      outputs=[last_mul.output[0]])
        fused.domain = "com.microsoft"
        fused_nodes.append(fused)

    self.remove_nodes(fused_away)
    self.add_nodes(fused_nodes)
    if fused_nodes:
        logger.info("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(fused_nodes)))
"""
Fuse Gelu with tanh into one node:
+---------------------------+
@ -111,7 +193,7 @@ class BertOnnxModelTF(BertOnnxModel):
nodes_to_add.append(gelu_node)
if len(nodes_to_add) > 0:
print("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
logger.info("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -151,11 +233,25 @@ class BertOnnxModelTF(BertOnnxModel):
nodes_to_add.append(new_node)
def __fuse_reshape_after_sotfmax(self, reshape_node, nodes_to_remove, nodes_to_add):
# Check that it is reshape after softmax.
path = self.match_parent_path(reshape_node, ['Transpose', 'MatMul', 'Softmax', 'Add'], [0, 0, 0, 0])
if path is None:
return
path0 = self.match_parent_path(reshape_node, ['Cast', 'Concat', 'Unsqueeze', 'Mul'], [1, 0, 0, 0])
if path0 is None:
return
cast_node, concat_node, unsqueeze_node, mul_node = path0
# Verify that cast has attribute "to" = 7
is_good_cast = False
for att in cast_node.attribute:
if att.name == 'to' and att.i == 7:
is_good_cast = True
break
if not is_good_cast:
return
if not len(concat_node.input) == 2:
return
@ -184,7 +280,7 @@ class BertOnnxModelTF(BertOnnxModel):
def __fuse_reshape_after_normalize(self, reshape_node, nodes_to_remove):
parent = self.get_parent(reshape_node, 0)
if not parent.op_type == self.normalize_name:
if parent is None or not parent.op_type == self.normalize_name:
return
parent_path = self.match_parent_path(
@ -199,6 +295,9 @@ class BertOnnxModelTF(BertOnnxModel):
self.replace_input_of_all_nodes(reshape_node.output[0], reshape_node.input[0])
def fuse_reshape(self):
nodes = self.nodes()
input_name_to_nodes = self.input_name_to_nodes()
@ -212,7 +311,7 @@ class BertOnnxModelTF(BertOnnxModel):
self.__fuse_reshape_after_sotfmax(reshape_node, nodes_to_remove, nodes_to_add)
self.__fuse_reshape_after_normalize(reshape_node, nodes_to_remove)
print("Count of nodes removed for Reshape fuse:", len(nodes_to_remove))
logger.info(f"Count of nodes removed for Reshape fuse:{len(nodes_to_remove)}")
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
@ -238,51 +337,81 @@ class BertOnnxModelTF(BertOnnxModel):
layernorm_nodes = []
for node in self.nodes():
if node.op_type == 'Add':
return_indice=[]
parent_nodes = self.match_parent_path(
node,
['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
[ 1, 1, 1, 0, 0, 0, 0, 0, 0, 1],
output_name_to_node)
[ 1, 1, None, 0, 0, 0, None, 0, 0, None],
output_name_to_node,
return_indice=return_indice)
if parent_nodes is None:
continue
assert len(return_indice) == 3
if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
logger.debug("return indice is exepected in [0, 1], but got {return_indice}")
continue
sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes
mul_node_3 = self.get_parent(node, 0, output_name_to_node)
mul_node_3 = self.match_parent(node, 'Mul', 0, output_name_to_node)
if mul_node_3 is None:
logger.debug("mul_node_3 not found")
continue
root_node = self.get_parent(reduce_mean_node_1, 0, output_name_to_node)
if root_node is None:
logger.debug("root node is none")
continue
i, add_weight = self.get_constant_input(add_node_0)
#if add_weight is None or add_weight <= 0 or add_weight > 1.0E-5:
# continue
i, epsilon = self.get_constant_input(add_node_0)
if epsilon is None or epsilon <= 0 or epsilon > 1.0E-5:
logger.debug("epsilon is not matched")
continue
nodes_to_remove.extend([node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1,mul_node_3])
if reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input:
logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
continue
if mul_node_2.input[0] != mul_node_2.input[1]:
logger.debug("mul_node_2 shall have two same inputs")
continue
subgraph_nodes = [node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1,mul_node_3]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.input_name_to_nodes(), self.output_name_to_node()):
logger.debug("not safe to fuse layer normalization")
continue
nodes_to_remove.extend(subgraph_nodes)
weight_input = mul_node_1.input[1]
bias_input = sub_node_0.input[0]
if root_node.op_type == 'Add':
nodes_to_remove.append(root_node)
normalize_node = onnx.helper.make_node(self.normalize_name,
inputs=[root_node.input[0], root_node.input[1], weight_input, bias_input],
outputs=[node.output[0]],
name=self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm"))
normalize_node.domain = "com.microsoft"
skip_layernorm_nodes.extend([normalize_node])
else:
normalize_node = onnx.helper.make_node('LayerNormalization',
inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
outputs=[node.output[0]], epsilon=add_weight)
layernorm_nodes.extend([normalize_node])
subgraph_nodes.append(root_node)
if not self.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.input_name_to_nodes(), self.output_name_to_node()):
subgraph_nodes.pop()
else:
nodes_to_remove.append(root_node)
normalize_node = onnx.helper.make_node(self.normalize_name,
inputs=[root_node.input[0], root_node.input[1], weight_input, bias_input],
outputs=[node.output[0]],
name=self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm"))
normalize_node.domain = "com.microsoft"
skip_layernorm_nodes.extend([normalize_node])
continue
normalize_node = onnx.helper.make_node('LayerNormalization',
inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
outputs=[node.output[0]], epsilon=epsilon)
layernorm_nodes.extend([normalize_node])
self.remove_nodes(nodes_to_remove)
self.add_nodes(skip_layernorm_nodes)
self.add_nodes(layernorm_nodes)
print("Fused SkipLayerNormalization count:", len(skip_layernorm_nodes))
print("Fused LayerNormalization count:", len(layernorm_nodes))
logger.info(f"Fused SkipLayerNormalization count: {len(skip_layernorm_nodes)}")
logger.info(f"Fused LayerNormalization count: {len(layernorm_nodes)}")
def remove_identity(self):
nodes_to_remove = []
@ -292,7 +421,7 @@ class BertOnnxModelTF(BertOnnxModel):
self.replace_input_of_all_nodes(node.output[0], node.input[0])
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)
print("Removed Identity count:", len(nodes_to_remove))
logger.info(f"Removed Identity count: {len(nodes_to_remove)}")
def fuse_word_embedding(self):
nodes_to_remove = []
@ -319,7 +448,7 @@ class BertOnnxModelTF(BertOnnxModel):
nodes_to_remove.extend([reshape_node_0, cast_node_0, concat_node, unsqueeze_node, cast_node_1, squeeze_node, slice_node, cast_node_2, shape_node, reshape_node_2, node])
self.remove_nodes(nodes_to_remove)
print("Fused word embedding" if len(nodes_to_remove) > 0 else "Failed to fuse word embedding")
logger.info("Fused word embedding" if len(nodes_to_remove) > 0 else "Failed to fuse word embedding")
def fuse_segment_embedding(self):
nodes_to_remove = []
@ -328,11 +457,16 @@ class BertOnnxModelTF(BertOnnxModel):
data_path = self.match_parent_path(node, ['MatMul', 'OneHot', 'Reshape'], [0, 0, 0])
if data_path is None:
continue
matmul_node, onehot_node, reshape_node_0 = data_path
concat_node_0 = self.get_parent(onehot_node, 2, self.output_name_to_node())
if not concat_node_0.op_type == 'Concat':
continue
matmul_node, onehot_node, reshape_node_0 = data_path
subgraph_nodes = [matmul_node, onehot_node, reshape_node_0]
if self.get_initializer(onehot_node.input[2]) is None:
concat_node_0 = self.get_parent(onehot_node, 2)
if concat_node_0 is None or concat_node_0.op_type != 'Concat':
continue
subgraph_nodes.append(concat_node_0)
shape_path = self.match_parent_path(
node,
['Cast', 'Concat', 'Unsqueeze', 'Cast', 'Squeeze', 'Slice', 'Cast', 'Shape', 'Gather'],
@ -349,18 +483,13 @@ class BertOnnxModelTF(BertOnnxModel):
outputs=node.output,
name='segment_embedding_gather')
nodes_to_remove.extend([
matmul_node,
onehot_node,
reshape_node_0,
concat_node_0
])
nodes_to_remove.extend(subgraph_nodes)
nodes_to_remove.extend([cast_node_0, concat_node_1, unsqueeze_node, cast_node_1, squeeze_node, slice_node, cast_node_2, shape_node, node])
self.add_node(gather_node)
self.remove_nodes(nodes_to_remove)
print("Fused segment embedding" if len(nodes_to_remove) > 0 else "Failed to fuse segment embedding")
logger.info("Fused segment embedding" if len(nodes_to_remove) > 0 else "Failed to fuse segment embedding")
def fuse_mask(self):
nodes_to_remove = []
@ -408,10 +537,10 @@ class BertOnnxModelTF(BertOnnxModel):
self.add_node(unsqueeze_added_2)
self.remove_nodes(nodes_to_remove)
print("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
logger.info("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
def preprocess(self):
self.remove_identity()
self.fuse_word_embedding()
self.fuse_segment_embedding()
self.fuse_mask()
self.fuse_mask()

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import logging
import onnx
import sys
import argparse
@ -10,6 +11,8 @@ import numpy as np
from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
logger = logging.getLogger(__name__)
class OnnxModel:
def __init__(self, model, verbose):
self.model = model
@ -124,18 +127,57 @@ class OnnxModel:
return output_name_to_node[input]
def match_parent(self, node, parent_op_type, input_index=None, output_name_to_node=None, exclude=[]):
def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None):
    '''
    Find the first parent of a node that has the requested op_type.

    Inputs are scanned in order; inputs with no entry in output_name_to_node
    are skipped.

    Args:
        node: current node; its ``input`` names are scanned.
        parent_op_type (str): constraint of parent node op_type.
        output_name_to_node (dict): dictionary with output name as key, and node as value.
        exclude (list or None): nodes that are excluded (not allowed to match as parent).
            None (the default) excludes nothing.

    Returns:
        parent: The matched parent node. None if not found.
        index: The input index of matched parent node. None if not found.
    '''
    # A None default avoids sharing one mutable list object across calls.
    excluded = [] if exclude is None else exclude

    for index, input_name in enumerate(node.input):
        if input_name not in output_name_to_node:
            continue
        parent = output_name_to_node[input_name]
        if parent.op_type == parent_op_type and parent not in excluded:
            return parent, index

    return None, None
def match_parent(self, node, parent_op_type, input_index=None, output_name_to_node=None, exclude=[], return_indice=None):
'''
Find parent node based on constraints on op_type and index.
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_type (str): constraint of parent node op_type.
input_index (int or None): only check the parent given input index of current node.
output_name_to_node (dict): dictionary with output name as key, and node as value.
exclude (list): list of nodes that are excluded (not allowed to match as parent).
return_indice (list): a list to append the input index when input_index is None.
Returns:
parent: The matched parent node.
'''
assert node is not None
assert input_index is None or input_index >= 0
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
if input_index is None:
parents = self.get_parents(node, output_name_to_node)
for parent in parents:
if parent.op_type == parent_op_type and parent not in exclude:
return parent
return None
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
if return_indice is not None:
return_indice.append(index)
return parent
if input_index < 0 or input_index >= len(node.input):
if input_index >= len(node.input):
return None
parent = self.get_parent(node, input_index, output_name_to_node)
@ -144,7 +186,21 @@ class OnnxModel:
return None
def match_parent_path(self, node, parent_op_types, parent_input_index, output_name_to_node=None):
def match_parent_path(self, node, parent_op_types, parent_input_index, output_name_to_node=None, return_indice=None):
'''
Find a sequence of input edges based on constraints on parent op_type and index.
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_types (str): constraint of parent node op_type of each input edge.
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
output_name_to_node (dict): dictionary with output name as key, and node as value.
return_indice (list): a list to append the input index when there is no constraint on input index of an edge.
Returns:
parents: a list of matched parent node.
'''
assert(len(parent_input_index) == len(parent_op_types))
if output_name_to_node is None:
@ -153,7 +209,7 @@ class OnnxModel:
current_node = node
matched_parents = []
for i, op_type in enumerate(parent_op_types):
matched_parent = self.match_parent(current_node, op_type, parent_input_index[i], output_name_to_node, exclude=[])
matched_parent = self.match_parent(current_node, op_type, parent_input_index[i], output_name_to_node, exclude=[], return_indice=return_indice)
if matched_parent is None:
return None
@ -370,7 +426,7 @@ class OnnxModel:
self.remove_nodes(unused_nodes)
if len(unused_nodes) > 0:
print("Removed unused constant nodes:", len(unused_nodes))
logger.info(f"Removed unused constant nodes: {len(unused_nodes)}")
def update_graph(self):
graph = self.model.graph
@ -382,7 +438,7 @@ class OnnxModel:
if input_name not in remaining_input_names:
remaining_input_names.append(input_name)
if self.verbose:
print("remaining input names", remaining_input_names)
logger.info(f"remaining input names: {remaining_input_names}" )
# remove graph input that is not used
inputs_to_remove = []
@ -392,7 +448,8 @@ class OnnxModel:
for input in inputs_to_remove:
graph.input.remove(input)
if self.verbose:
print("remove unused input ", len(inputs_to_remove), [input.name for input in inputs_to_remove])
names_to_remove = [input.name for input in inputs_to_remove]
logger.info(f"remove {len(inputs_to_remove)} unused inputs: {names_to_remove}")
# remove weights that are not used
weights_to_remove = []
@ -406,21 +463,22 @@ class OnnxModel:
graph.initializer.remove(initializer)
if self.verbose:
print("remove unused initializers:", len(weights_to_remove), [initializer.name for initializer in weights_to_remove])
print("remaining initializers:", weights_to_keep)
names_to_remove = [initializer.name for initializer in weights_to_remove]
logger.info(f"remove {len(weights_to_remove)} unused initializers: {names_to_remove}")
logger.info(f"remaining initializers:{weights_to_keep}")
self.remove_unused_constant()
def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nodes, output_name_to_node):
for node in nodes_to_remove:
for output in node.output:
if output in keep_outputs:
for node_to_remove in nodes_to_remove:
for output_to_remove in node_to_remove.output:
if output_to_remove in keep_outputs:
continue
if output in input_name_to_nodes:
for node in input_name_to_nodes[output]:
if node not in nodes_to_remove:
if output_to_remove in input_name_to_nodes:
for impacted_node in input_name_to_nodes[output_to_remove]:
if impacted_node not in nodes_to_remove:
if self.verbose:
print("warning: it is not safe to remove nodes since output", output, "used by", node)
logger.warning(f"it is not safe to remove nodes since output {output_to_remove} is used by {impacted_node}")
return False
return True

View file

@ -63,3 +63,5 @@ See below for description of all the options:
By default, model uses float32 in computation. If this flag is specified, half-precision float will be used. This option is recommended for NVidia GPU with Tensor Core like V100 and T4. For older GPUs, float32 is likely faster.
- **verbose**: (*optional*)
Print verbose information when this flag is specified.
- **use_onnxruntime**: (*optional*)
Use onnxruntime to do optimization. This option is only available for PyTorch models right now.

View file

@ -7,35 +7,65 @@
# SkipLayerNormalization and EmbedLayerNormalization ops to optimize
# performance on NVidia GPU and CPU.
# Note: This script is not required for Bert model exported from PyTorch.
# OnnxRuntime has bert model optimization support internally. The recommended way is
# to set optimization level to ORT_ENABLE_EXTENDED during Bert model inference.
# See the following document for more information:
# https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md
# For Bert model exported from PyTorch, OnnxRuntime has bert model optimization support internally.
# You can use the option --use_onnxruntime to use model optimization from OnnxRuntime package.
# For Bert model file like name.onnx, optimized model for GPU or CPU from OnnxRuntime will output as
# name_ort_gpu.onnx or name_ort_cpu.onnx in the same directory.
# This script is retained for experimental purposes. Useful scenarios include the following:
# (1) Change model from fp32 to fp16.
# (2) Change input data type from int64 to int32.
# (3) Model cannot be handled to OnnxRuntime graph optimization, and you can modify this script to get optimized model.
# (3) Some model cannot be handled by OnnxRuntime, and you can modify this script to get optimized model.
# This script has been tested using the following models:
# (1) BertForSequenceClassification as in https://github.com/huggingface/transformers/blob/master/examples/run_glue.py
# PyTorch 1.2 or above, and exported to Onnx using opset version 10 or 11.
# (2) BertForQuestionAnswering as in https://github.com/huggingface/transformers/blob/master/examples/run_squad.py
# PyTorch 1.2 or above, and exported to Onnx using opset version 10 or 11.
import logging
import onnx
import os
import sys
import argparse
import numpy as np
from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
import onnxruntime
from BertOnnxModel import BertOnnxModel
from BertOnnxModelTF import BertOnnxModelTF
def main():
logger = logging.getLogger('')
def run_onnxruntime(onnx_model_path, use_gpu, optimized_model_path=None):
    """Optimize an ONNX model with onnxruntime and return the output path.

    A session is created at the ORT_ENABLE_EXTENDED optimization level with
    ``optimized_model_filepath`` set; the function then checks that the
    optimized file exists. The session is not used for inference.

    Args:
        onnx_model_path (str): path of the model to optimize.
        use_gpu (bool): when False the session is restricted to
            CPUExecutionProvider; when True the session must report
            CUDAExecutionProvider.
        optimized_model_path (str or None): where to save the optimized model.
            When None, "<input without .onnx>_ort_gpu.onnx" or
            "<input without .onnx>_ort_cpu.onnx" is used.

    Returns:
        str: path of the optimized model file.

    Raises:
        AssertionError: if GPU was requested but the session has no
            CUDAExecutionProvider, or the optimized file was not written.
    """
    if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
        # Only logged here; the provider assertion below stops the run.
        logger.error("There is no gpu for onnxruntime to do optimization.")

    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

    if optimized_model_path is None:
        # Strip the suffix only when it is really ".onnx"; blindly dropping
        # five characters would mangle any other file name.
        if onnx_model_path.lower().endswith(".onnx"):
            path_prefix = onnx_model_path[:-len(".onnx")]
        else:
            path_prefix = onnx_model_path
        optimized_model_path = "{}_ort_{}.onnx".format(path_prefix, "gpu" if use_gpu else "cpu")

    sess_options.optimized_model_filepath = optimized_model_path

    if not use_gpu:
        session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=['CPUExecutionProvider'])
    else:
        session = onnxruntime.InferenceSession(onnx_model_path, sess_options)
        # Make sure there is GPU
        assert 'CUDAExecutionProvider' in session.get_providers(), "CUDAExecutionProvider is not available"

    assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path), \
        "optimized model was not written to {}".format(optimized_model_path)
    logger.info("Save optimized model by onnxruntime to {}".format(optimized_model_path))
    return optimized_model_path
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, type=str)
parser.add_argument('--output', required=True, type=str)
parser.add_argument('--framework', required=True, type=str, help="Original framework. Only support TensorFlow and PyTorch")
parser.add_argument('--framework', required=True, type=str.lower, choices=["tensorflow", "pytorch"], help="Original framework")
# model parameters
# model parameters.
parser.add_argument('--num_heads', required=False, type=int, default=12, help="number of attention heads")
parser.add_argument('--hidden_size', required=False, type=int, default=768)
parser.add_argument('--sequence_length', required=False, type=int, default=128)
@ -56,23 +86,60 @@ def main():
parser.add_argument('--verbose', required=False, action='store_true')
parser.set_defaults(verbose=False)
parser.add_argument('--use_onnxruntime', required=False, action='store_true')
parser.set_defaults(use_onnxruntime=False)
args = parser.parse_args()
return args
def optimize_model(input, framework, gpu_only, num_heads, hidden_size, sequence_length, input_int32, float16, verbose):
    """Load an ONNX model from disk and run the BERT graph optimization on it.

    Args:
        input: Path of the ONNX model file to read.
        framework: 'tensorflow' selects BertOnnxModelTF; any other value is
            handled as 'pytorch' and selects BertOnnxModel.
        gpu_only, num_heads, hidden_size, sequence_length, input_int32,
        float16, verbose: Forwarded unchanged to the model wrapper constructor.

    Returns:
        The model wrapper after its ``optimize()`` method has run. The
        underlying ModelProto is available as its ``model`` attribute.
    """
    model = ModelProto()
    with open(input, "rb") as f:
        model.ParseFromString(f.read())

    if framework == 'tensorflow':
        bert_model = BertOnnxModelTF(model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only,
                                     verbose)
    else:  # framework == 'pytorch'
        bert_model = BertOnnxModel(model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only,
                                   verbose)

    bert_model.optimize()
    return bert_model
def output_model(model, output):
    """Write ``model`` to the file ``output``.

    A path ending in ".json" gets the text produced by ``str(model)`` (only
    meant for tests); nothing is written for that path when ``model`` is not a
    ModelProto. Any other path gets the serialized binary protobuf.

    Args:
        model: The model to save.
        output: Destination file path.
    """
    if output.endswith(".json"):  # output to JSON. Only for test purpose.
        if isinstance(model, ModelProto):
            with open(output, "w") as out:
                out.write(str(model))
            # ``output`` already carries the .json extension, so it is logged as is.
            logger.info("Output JSON to {}".format(output))
    else:
        with open(output, "wb") as out:
            out.write(model.SerializeToString())
        logger.info("Output final model to {}".format(output))
def main():
    """Command line entry point: parse options, optimize the model, save it."""
    args = parse_arguments()

    # output logging to stdout
    log_handler = logging.StreamHandler(sys.stdout)
    log_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
    log_handler.setLevel(logging.DEBUG)
    logger.addHandler(log_handler)
    logger.setLevel(logging.DEBUG)

    input_model_path = args.input
    if args.use_onnxruntime:
        # The framework name lives on ``args``; there is no bare ``framework``
        # variable in this scope.
        if args.framework == 'tensorflow':
            logger.warning(
                "onnxruntime does not have optimization for tensorflow model. Ignore the option --use_onnxruntime.")
        else:
            # Continue from the model file that onnxruntime already optimized.
            input_model_path = run_onnxruntime(input_model_path, args.gpu_only)

    bert_model = optimize_model(input_model_path, args.framework, args.gpu_only, args.num_heads, args.hidden_size,
                                args.sequence_length, args.input_int32, args.float16, args.verbose)

    # output_model is the only writer of args.output, so the format it picks
    # from the file extension is what ends up on disk.
    output_model(bert_model.model, args.output)
# Run the optimizer only when this file is executed as a script.
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,183 @@
#!/usr/bin/env python
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG
import unittest
import os
import onnx
import onnxruntime
from onnx import helper, TensorProto, ModelProto
from onnx.helper import make_node, make_tensor_value_info
import numpy as np
from onnx import numpy_helper
from bert_model_optimization import optimize_model, run_onnxruntime
from OnnxModel import OnnxModel
class TestBertOptimization(unittest.TestCase):
    """Check the number of fused nodes produced for the BERT test models.

    Two optimization paths are exercised: onnxruntime's own graph
    optimization (``run_onnxruntime``) and the script's ``optimize_model``.
    Model paths are relative, so the tests must be started from the directory
    that contains ``test_data``.
    """

    # Scratch file that receives the model written by onnxruntime.
    TEMP_MODEL_PATH = 'temp.onnx'

    def get_model(self, framework, index):
        """Return the relative path of test model ``index`` for ``framework``.

        "pytorch" selects the PyTorch SQuAD export; any other value selects
        the TensorFlow MRPC export.
        """
        # os.path.join keeps the path valid on every platform instead of
        # hard-coding the Windows separator.
        if framework == "pytorch":
            return os.path.join('test_data', 'bert_squad_pytorch1.4_opset11',
                                'BertForQuestionAnswering_{}.onnx'.format(index))
        return os.path.join('test_data', 'bert_mrpc_tensorflow2.1_opset10',
                            'TFBertForSequenceClassification_{}.onnx'.format(index))

    def verify_node_count(self, bert_model, expected_node_count):
        """Assert that ``bert_model`` has the expected node count per op type.

        Args:
            bert_model: Object exposing ``get_nodes_by_op_type(op_type)``.
            expected_node_count: Mapping of op type name to expected count.
        """
        for op_type, count in expected_node_count.items():
            actual = len(bert_model.get_nodes_by_op_type(op_type))
            # The message reports which op type is off when the check fails.
            self.assertEqual(actual, count, "{}:{} expected={}".format(op_type, actual, count))

    def _skip_without_gpu(self, test_name):
        """Skip the running test when onnxruntime has no CUDA provider."""
        if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
            self.skipTest("skip {} since no gpu found".format(test_name))

    def _remove_temp_model(self):
        """Delete the scratch model file if it was created."""
        if os.path.exists(self.TEMP_MODEL_PATH):
            os.remove(self.TEMP_MODEL_PATH)

    def _optimize_with_onnxruntime(self, input_path, use_gpu):
        """Optimize ``input_path`` with onnxruntime and wrap the result in OnnxModel."""
        model = ModelProto()
        try:
            run_onnxruntime(input_path, use_gpu=use_gpu, optimized_model_path=self.TEMP_MODEL_PATH)
            with open(self.TEMP_MODEL_PATH, "rb") as f:
                model.ParseFromString(f.read())
        finally:
            # Clean up even when optimization or parsing fails.
            self._remove_temp_model()
        return OnnxModel(model, verbose=False)

    def test_pytorch_model_0_cpu_onnxruntime(self):
        bert_model = self._optimize_with_onnxruntime(self.get_model("pytorch", 0), use_gpu=False)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'Gelu': 0,
            'FastGelu': 0,
            'BiasGelu': 12
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_pytorch_model_0_gpu_onnxruntime(self):
        self._skip_without_gpu("test_pytorch_model_0_gpu_onnxruntime")
        bert_model = self._optimize_with_onnxruntime(self.get_model("pytorch", 0), use_gpu=True)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'Gelu': 0,
            'FastGelu': 12,
            'BiasGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_pytorch_model_1_cpu_onnxruntime(self):
        bert_model = self._optimize_with_onnxruntime(self.get_model("pytorch", 1), use_gpu=False)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'LayerNormalization': 24,
            'SkipLayerNormalization': 0,
            'Gelu': 0,
            'FastGelu': 0,
            'BiasGelu': 12
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_pytorch_model_1_gpu_onnxruntime(self):
        self._skip_without_gpu("test_pytorch_model_1_gpu_onnxruntime")
        bert_model = self._optimize_with_onnxruntime(self.get_model("pytorch", 1), use_gpu=True)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'LayerNormalization': 24,
            'SkipLayerNormalization': 0,
            'Gelu': 0,
            'FastGelu': 12,
            'BiasGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_pytorch_model_0_cpu(self):
        input_path = self.get_model("pytorch", 0)
        bert_model = optimize_model(input_path,
                                    framework='pytorch',
                                    gpu_only=False,
                                    num_heads=2,
                                    hidden_size=8,
                                    sequence_length=10,
                                    input_int32=False,
                                    float16=False,
                                    verbose=False)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'Gelu': 12,
            'FastGelu': 0,
            'BiasGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_pytorch_model_0_gpu(self):
        self._skip_without_gpu("test_pytorch_model_0_gpu")
        input_path = self.get_model("pytorch", 0)
        bert_model = optimize_model(input_path,
                                    framework='pytorch',
                                    gpu_only=True,
                                    num_heads=2,
                                    hidden_size=8,
                                    sequence_length=10,
                                    input_int32=False,
                                    float16=False,
                                    verbose=False)

        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'FastGelu': 12,
            'Gelu': 0,
            'BiasGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count)

    def test_tensorflow_model_1_cpu(self):
        input_path = self.get_model("tensorflow", 1)

        # The model need constant folding. Use onnxruntime to do so for now.
        try:
            run_onnxruntime(input_path, use_gpu=False, optimized_model_path=self.TEMP_MODEL_PATH)
            bert_model = optimize_model(self.TEMP_MODEL_PATH,
                                        framework='tensorflow',
                                        gpu_only=False,
                                        num_heads=2,
                                        hidden_size=8,
                                        sequence_length=7,
                                        input_int32=False,
                                        float16=False,
                                        verbose=False)
        finally:
            # Clean up even when one of the two optimization steps fails.
            self._remove_temp_model()

        # Optimization for tensorflow model is still on-going.
        # TODO: update this after code complete.
        expected_node_count = {
            'EmbedLayerNormalization': 0,
            'Attention': 0,
            'LayerNormalization': 0,
            'SkipLayerNormalization': 25,
            'BiasGelu': 0,
            'Gelu': 12,
            'FastGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count)
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

View file

@ -0,0 +1 @@
Boutput_1JwąÔZ»

View file

@ -0,0 +1,2 @@

BstartJ(Tţ ˝đ·8˝â*0˝<43>s+˝ĎlŘĽć«*˝Dű*˝<>÷řĽ&ü)˝

View file

@ -0,0 +1,2 @@

BendJ(€'±<Œ <9ù ½<@|=ÌAC=9lå<*¨<5Ô<35><]ñ;