mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Improve BERT optimization script: Gelu and LayerNorm for models from Tensorflow 2.* (#2957)
* Add unit test. Add an option --use_onnxruntime to use onnxruntime to do optimization for pytorch model. Update layer norm and gelu for tensorflow 2.1 keras bert model. Add logging and use f-strings. Add extra checking for tensorflow model reshape fusion. Allow output model to json for test purpose. update match parent path utility function to return index * remove function not used.
This commit is contained in:
parent
0beb75ce77
commit
62383b0328
18 changed files with 556 additions and 110 deletions
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import onnx
|
||||
import sys
|
||||
import argparse
|
||||
|
|
@ -11,6 +12,8 @@ from collections import deque
|
|||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from OnnxModel import OnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BertOnnxModel(OnnxModel):
|
||||
def __init__(self, model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose):
|
||||
assert num_heads > 0
|
||||
|
|
@ -219,7 +222,7 @@ class BertOnnxModel(OnnxModel):
|
|||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.update_graph()
|
||||
print("Fused Attention count:", attention_count)
|
||||
logger.info(f"Fused Attention count:{attention_count}")
|
||||
|
||||
def fuse_gelu(self, gelu_op_name):
|
||||
self.fuse_gelu_with_elf(gelu_op_name)
|
||||
|
|
@ -294,7 +297,7 @@ class BertOnnxModel(OnnxModel):
|
|||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
if len(nodes_to_add) > 0:
|
||||
print("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(nodes_to_add)))
|
||||
logger.info("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(nodes_to_add)))
|
||||
|
||||
"""
|
||||
Fuse Gelu with tanh into one node:
|
||||
|
|
@ -389,7 +392,7 @@ class BertOnnxModel(OnnxModel):
|
|||
nodes_to_add.append(gelu_node)
|
||||
|
||||
if len(nodes_to_add) > 0:
|
||||
print("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
|
||||
logger.info("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -437,7 +440,7 @@ class BertOnnxModel(OnnxModel):
|
|||
nodes_to_add.append(gelu_node)
|
||||
|
||||
if len(nodes_to_add) > 0:
|
||||
print("Fused FastGelu with Bias count:", len(nodes_to_add))
|
||||
logger.info(f"Fused FastGelu with Bias count:{len(nodes_to_add)}")
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -484,7 +487,7 @@ class BertOnnxModel(OnnxModel):
|
|||
nodes_to_add.append(new_node)
|
||||
|
||||
if len(nodes_to_add) > 0:
|
||||
print("Fused SkipLayerNormalization with Bias count:", len(nodes_to_add))
|
||||
logger.info(f"Fused SkipLayerNormalization with Bias count:{len(nodes_to_add)}")
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -540,7 +543,11 @@ class BertOnnxModel(OnnxModel):
|
|||
concat_2 = self.get_initializer(concat_node.input[2])
|
||||
if concat_2 is None:
|
||||
continue
|
||||
shape.extend(numpy_helper.to_array(concat_2))
|
||||
concat_value = numpy_helper.to_array(concat_2)
|
||||
if isinstance(concat_value, list):
|
||||
shape.extend(concat_value)
|
||||
else:
|
||||
shape.append(concat_value)
|
||||
|
||||
if len(concat_node.input) == 4 and self.get_initializer(concat_node.input[3]) is None:
|
||||
path2 = self.match_parent_path(concat_node, ['Unsqueeze', 'Div', 'Gather', 'Shape'], [3, 0, 0, 0], output_name_to_node)
|
||||
|
|
@ -552,7 +559,12 @@ class BertOnnxModel(OnnxModel):
|
|||
concat_3 = self.get_initializer(concat_node.input[3])
|
||||
if concat_3 is None:
|
||||
continue
|
||||
shape.extend(numpy_helper.to_array(concat_3))
|
||||
|
||||
concat_value = numpy_helper.to_array(concat_3)
|
||||
if isinstance(concat_value, list):
|
||||
shape.extend(concat_value)
|
||||
else:
|
||||
shape.append(concat_value)
|
||||
|
||||
root_input = reshape_node.input[0]
|
||||
same_shape_input = True
|
||||
|
|
@ -582,7 +594,7 @@ class BertOnnxModel(OnnxModel):
|
|||
nodes_to_remove.extend(path3)
|
||||
nodes_to_add.append(new_node)
|
||||
|
||||
print("Fused Reshape count:", len(nodes_to_add))
|
||||
logger.info(f"Fused Reshape count:{len(nodes_to_add)}")
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -615,10 +627,10 @@ class BertOnnxModel(OnnxModel):
|
|||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
if len(self.mask_indice) == 0:
|
||||
print("skip embed layer fusion since mask input is not found")
|
||||
logger.info("skip embed layer fusion since mask input is not found")
|
||||
return
|
||||
if len(self.mask_indice) > 1:
|
||||
print("skip embed layer fusion since there are multiple mask inputs found")
|
||||
logger.info("skip embed layer fusion since there are multiple mask inputs found")
|
||||
return
|
||||
mask_input_name = next(iter(self.mask_indice))
|
||||
mask_output_name = self.mask_indice[mask_input_name]
|
||||
|
|
@ -635,13 +647,14 @@ class BertOnnxModel(OnnxModel):
|
|||
break
|
||||
|
||||
if normalize_node is None:
|
||||
print("Failed to find embedding layer")
|
||||
logger.info("Failed to find embedding layer")
|
||||
return
|
||||
|
||||
# Here we assume the order of embedding is word_embedding +
|
||||
# position_embedding + segment_embedding.
|
||||
word_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
|
||||
if word_embedding_path is None:
|
||||
print("Failed to find word embedding")
|
||||
logger.info("Failed to find word embedding")
|
||||
return
|
||||
add_node, word_embedding_gather = word_embedding_path
|
||||
input_ids = word_embedding_gather.input[1]
|
||||
|
|
@ -654,14 +667,14 @@ class BertOnnxModel(OnnxModel):
|
|||
if position_embedding_path is None:
|
||||
position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
|
||||
if position_embedding_path is None:
|
||||
print("Failed to find position embedding")
|
||||
logger.info("Failed to find position embedding")
|
||||
return
|
||||
position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
|
||||
else:
|
||||
position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
|
||||
|
||||
if not position_embedding_shape is None and position_embedding_shape.input[0] != input_ids:
|
||||
print("position and word embedding is expected to be applied on same input")
|
||||
logger.info("position and word embedding is expected to be applied on same input")
|
||||
return
|
||||
else:
|
||||
_, position_embedding_weight_node = position_embedding_path
|
||||
|
|
@ -670,7 +683,7 @@ class BertOnnxModel(OnnxModel):
|
|||
if segment_embedding_path is None:
|
||||
segment_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
|
||||
if segment_embedding_path is None:
|
||||
print("Failed to find segment embedding")
|
||||
logger.info("Failed to find segment embedding")
|
||||
return
|
||||
_, segment_embedding_gather = segment_embedding_path
|
||||
else:
|
||||
|
|
@ -724,7 +737,7 @@ class BertOnnxModel(OnnxModel):
|
|||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_node(embed_node)
|
||||
self.update_graph()
|
||||
print("Fused EmbedLayerNormalization count: 1")
|
||||
logger.info("Fused EmbedLayerNormalization count: 1")
|
||||
|
||||
# Change graph input data type int32 if needed.
|
||||
if self.input_int32:
|
||||
|
|
@ -816,17 +829,6 @@ class BertOnnxModel(OnnxModel):
|
|||
| |
|
||||
+----------------------+
|
||||
|
||||
TODO: Batch Layer Norm from Keras in Tensorflow:
|
||||
+----------------------+
|
||||
| |
|
||||
| v (B) (A)
|
||||
Add --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add
|
||||
| | | ^ ^
|
||||
| | | | |
|
||||
| +----------------------------------------------------------------------------|-------+ |
|
||||
| v |
|
||||
+-------------------------------------------------------------------------------------> Mul--------------------+
|
||||
|
||||
"""
|
||||
def fuse_layer_norm(self):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
|
|
@ -912,8 +914,8 @@ class BertOnnxModel(OnnxModel):
|
|||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(skip_layernorm_nodes)
|
||||
self.add_nodes(layernorm_nodes)
|
||||
print("Fused SkipLayerNormalization count:", len(skip_layernorm_nodes))
|
||||
print("Fused LayerNormalization count:", len(layernorm_nodes))
|
||||
logger.info(f"Fused SkipLayerNormalization count: {len(skip_layernorm_nodes)}")
|
||||
logger.info(f"Fused LayerNormalization count: {len(layernorm_nodes)}")
|
||||
|
||||
def preprocess(self):
|
||||
return
|
||||
|
|
@ -945,4 +947,4 @@ class BertOnnxModel(OnnxModel):
|
|||
# Use symbolic batch dimension in input and output.
|
||||
self.update_dynamic_batch_io()
|
||||
|
||||
print("opset verion", self.model.opset_import[0].version)
|
||||
logger.info(f"opset verion: {self.model.opset_import[0].version}")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import onnx
|
||||
import sys
|
||||
import argparse
|
||||
|
|
@ -11,10 +12,91 @@ from collections import deque
|
|||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
from BertOnnxModel import BertOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BertOnnxModelTF(BertOnnxModel):
|
||||
def __init(self, model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose):
|
||||
super().__init__(model, num_heads, hidden_size, sequence_length, verbose)
|
||||
|
||||
"""
|
||||
Fuse Gelu with Erf into one node:
|
||||
+----------------------------------------------+
|
||||
| |
|
||||
| v
|
||||
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
|
||||
(A=0.7071067690849304) (B=1) (B=0.5)
|
||||
|
||||
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
|
||||
"""
|
||||
def fuse_gelu_with_elf(self, gelu_op_name):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
|
||||
for node in self.get_nodes_by_op_type('Erf'):
|
||||
erf_node = node
|
||||
|
||||
if erf_node.output[0] not in input_name_to_nodes:
|
||||
continue
|
||||
children = input_name_to_nodes[erf_node.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != 'Add':
|
||||
continue
|
||||
add_after_erf = children[0]
|
||||
|
||||
if not self.has_constant_input(add_after_erf, 1):
|
||||
continue
|
||||
|
||||
if add_after_erf.output[0] not in input_name_to_nodes:
|
||||
continue
|
||||
children = input_name_to_nodes[add_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != 'Mul':
|
||||
continue
|
||||
mul_half = children[0]
|
||||
|
||||
if not self.has_constant_input(mul_half, 0.5):
|
||||
continue
|
||||
|
||||
first_mul = self.match_parent(erf_node, 'Mul', 0, output_name_to_node)
|
||||
if first_mul is None:
|
||||
continue
|
||||
|
||||
i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
|
||||
if i < 0:
|
||||
continue
|
||||
|
||||
root_node = self.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
|
||||
if root_node is None:
|
||||
continue
|
||||
|
||||
if mul_half.output[0] not in input_name_to_nodes:
|
||||
continue
|
||||
children = input_name_to_nodes[mul_half.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != 'Mul':
|
||||
continue
|
||||
last_mul = children[0]
|
||||
|
||||
if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
|
||||
continue
|
||||
|
||||
subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
|
||||
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [last_mul.output[0]], input_name_to_nodes, output_name_to_node):
|
||||
continue
|
||||
|
||||
nodes_to_remove.extend(subgraph_nodes)
|
||||
gelu_node = onnx.helper.make_node(gelu_op_name,
|
||||
inputs=[root_node.output[0]],
|
||||
outputs=[last_mul.output[0]])
|
||||
gelu_node.domain = "com.microsoft"
|
||||
nodes_to_add.append(gelu_node)
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
if len(nodes_to_add) > 0:
|
||||
logger.info("Fused {} count:{}".format('FastGelu (approximation)' if gelu_op_name == 'FastGelu' else 'Gelu', len(nodes_to_add)))
|
||||
|
||||
|
||||
"""
|
||||
Fuse Gelu with tanh into one node:
|
||||
+---------------------------+
|
||||
|
|
@ -111,7 +193,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
nodes_to_add.append(gelu_node)
|
||||
|
||||
if len(nodes_to_add) > 0:
|
||||
print("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
|
||||
logger.info("Fused {} count: {}".format('Gelu (FastGelu fits better)' if gelu_op_name == 'Gelu' else 'FastGelu', len(nodes_to_add)))
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -151,11 +233,25 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
nodes_to_add.append(new_node)
|
||||
|
||||
def __fuse_reshape_after_sotfmax(self, reshape_node, nodes_to_remove, nodes_to_add):
|
||||
# Check that it is reshape after softmax.
|
||||
path = self.match_parent_path(reshape_node, ['Transpose', 'MatMul', 'Softmax', 'Add'], [0, 0, 0, 0])
|
||||
if path is None:
|
||||
return
|
||||
|
||||
path0 = self.match_parent_path(reshape_node, ['Cast', 'Concat', 'Unsqueeze', 'Mul'], [1, 0, 0, 0])
|
||||
if path0 is None:
|
||||
return
|
||||
cast_node, concat_node, unsqueeze_node, mul_node = path0
|
||||
|
||||
# Verify that cast has attribute "to" = 7
|
||||
is_good_cast = False
|
||||
for att in cast_node.attribute:
|
||||
if att.name == 'to' and att.i == 7:
|
||||
is_good_cast = True
|
||||
break
|
||||
if not is_good_cast:
|
||||
return
|
||||
|
||||
if not len(concat_node.input) == 2:
|
||||
return
|
||||
|
||||
|
|
@ -184,7 +280,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
|
||||
def __fuse_reshape_after_normalize(self, reshape_node, nodes_to_remove):
|
||||
parent = self.get_parent(reshape_node, 0)
|
||||
if not parent.op_type == self.normalize_name:
|
||||
if parent is None or not parent.op_type == self.normalize_name:
|
||||
return
|
||||
|
||||
parent_path = self.match_parent_path(
|
||||
|
|
@ -199,6 +295,9 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
|
||||
self.replace_input_of_all_nodes(reshape_node.output[0], reshape_node.input[0])
|
||||
|
||||
|
||||
|
||||
|
||||
def fuse_reshape(self):
|
||||
nodes = self.nodes()
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
|
|
@ -212,7 +311,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.__fuse_reshape_after_sotfmax(reshape_node, nodes_to_remove, nodes_to_add)
|
||||
self.__fuse_reshape_after_normalize(reshape_node, nodes_to_remove)
|
||||
|
||||
print("Count of nodes removed for Reshape fuse:", len(nodes_to_remove))
|
||||
logger.info(f"Count of nodes removed for Reshape fuse:{len(nodes_to_remove)}")
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
|
@ -238,51 +337,81 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
layernorm_nodes = []
|
||||
for node in self.nodes():
|
||||
if node.op_type == 'Add':
|
||||
return_indice=[]
|
||||
parent_nodes = self.match_parent_path(
|
||||
node,
|
||||
['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
|
||||
[ 1, 1, 1, 0, 0, 0, 0, 0, 0, 1],
|
||||
output_name_to_node)
|
||||
[ 1, 1, None, 0, 0, 0, None, 0, 0, None],
|
||||
output_name_to_node,
|
||||
return_indice=return_indice)
|
||||
|
||||
if parent_nodes is None:
|
||||
continue
|
||||
|
||||
assert len(return_indice) == 3
|
||||
if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
|
||||
logger.debug("return indice is exepected in [0, 1], but got {return_indice}")
|
||||
continue
|
||||
|
||||
sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes
|
||||
|
||||
mul_node_3 = self.get_parent(node, 0, output_name_to_node)
|
||||
mul_node_3 = self.match_parent(node, 'Mul', 0, output_name_to_node)
|
||||
if mul_node_3 is None:
|
||||
logger.debug("mul_node_3 not found")
|
||||
continue
|
||||
|
||||
root_node = self.get_parent(reduce_mean_node_1, 0, output_name_to_node)
|
||||
if root_node is None:
|
||||
logger.debug("root node is none")
|
||||
continue
|
||||
|
||||
i, add_weight = self.get_constant_input(add_node_0)
|
||||
#if add_weight is None or add_weight <= 0 or add_weight > 1.0E-5:
|
||||
# continue
|
||||
i, epsilon = self.get_constant_input(add_node_0)
|
||||
if epsilon is None or epsilon <= 0 or epsilon > 1.0E-5:
|
||||
logger.debug("epsilon is not matched")
|
||||
continue
|
||||
|
||||
nodes_to_remove.extend([node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1,mul_node_3])
|
||||
if reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input:
|
||||
logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
|
||||
continue
|
||||
|
||||
if mul_node_2.input[0] != mul_node_2.input[1]:
|
||||
logger.debug("mul_node_2 shall have two same inputs")
|
||||
continue
|
||||
|
||||
subgraph_nodes = [node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1,mul_node_3]
|
||||
if not self.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.input_name_to_nodes(), self.output_name_to_node()):
|
||||
logger.debug("not safe to fuse layer normalization")
|
||||
continue
|
||||
|
||||
nodes_to_remove.extend(subgraph_nodes)
|
||||
|
||||
weight_input = mul_node_1.input[1]
|
||||
bias_input = sub_node_0.input[0]
|
||||
|
||||
if root_node.op_type == 'Add':
|
||||
nodes_to_remove.append(root_node)
|
||||
normalize_node = onnx.helper.make_node(self.normalize_name,
|
||||
inputs=[root_node.input[0], root_node.input[1], weight_input, bias_input],
|
||||
outputs=[node.output[0]],
|
||||
name=self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm"))
|
||||
normalize_node.domain = "com.microsoft"
|
||||
skip_layernorm_nodes.extend([normalize_node])
|
||||
else:
|
||||
normalize_node = onnx.helper.make_node('LayerNormalization',
|
||||
inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
|
||||
outputs=[node.output[0]], epsilon=add_weight)
|
||||
layernorm_nodes.extend([normalize_node])
|
||||
subgraph_nodes.append(root_node)
|
||||
if not self.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.input_name_to_nodes(), self.output_name_to_node()):
|
||||
subgraph_nodes.pop()
|
||||
else:
|
||||
nodes_to_remove.append(root_node)
|
||||
normalize_node = onnx.helper.make_node(self.normalize_name,
|
||||
inputs=[root_node.input[0], root_node.input[1], weight_input, bias_input],
|
||||
outputs=[node.output[0]],
|
||||
name=self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm"))
|
||||
normalize_node.domain = "com.microsoft"
|
||||
skip_layernorm_nodes.extend([normalize_node])
|
||||
continue
|
||||
|
||||
normalize_node = onnx.helper.make_node('LayerNormalization',
|
||||
inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
|
||||
outputs=[node.output[0]], epsilon=epsilon)
|
||||
layernorm_nodes.extend([normalize_node])
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(skip_layernorm_nodes)
|
||||
self.add_nodes(layernorm_nodes)
|
||||
print("Fused SkipLayerNormalization count:", len(skip_layernorm_nodes))
|
||||
print("Fused LayerNormalization count:", len(layernorm_nodes))
|
||||
logger.info(f"Fused SkipLayerNormalization count: {len(skip_layernorm_nodes)}")
|
||||
logger.info(f"Fused LayerNormalization count: {len(layernorm_nodes)}")
|
||||
|
||||
def remove_identity(self):
|
||||
nodes_to_remove = []
|
||||
|
|
@ -292,7 +421,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.replace_input_of_all_nodes(node.output[0], node.input[0])
|
||||
nodes_to_remove.append(node)
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
print("Removed Identity count:", len(nodes_to_remove))
|
||||
logger.info(f"Removed Identity count: {len(nodes_to_remove)}")
|
||||
|
||||
def fuse_word_embedding(self):
|
||||
nodes_to_remove = []
|
||||
|
|
@ -319,7 +448,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
nodes_to_remove.extend([reshape_node_0, cast_node_0, concat_node, unsqueeze_node, cast_node_1, squeeze_node, slice_node, cast_node_2, shape_node, reshape_node_2, node])
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
print("Fused word embedding" if len(nodes_to_remove) > 0 else "Failed to fuse word embedding")
|
||||
logger.info("Fused word embedding" if len(nodes_to_remove) > 0 else "Failed to fuse word embedding")
|
||||
|
||||
def fuse_segment_embedding(self):
|
||||
nodes_to_remove = []
|
||||
|
|
@ -328,11 +457,16 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
data_path = self.match_parent_path(node, ['MatMul', 'OneHot', 'Reshape'], [0, 0, 0])
|
||||
if data_path is None:
|
||||
continue
|
||||
matmul_node, onehot_node, reshape_node_0 = data_path
|
||||
concat_node_0 = self.get_parent(onehot_node, 2, self.output_name_to_node())
|
||||
if not concat_node_0.op_type == 'Concat':
|
||||
continue
|
||||
|
||||
matmul_node, onehot_node, reshape_node_0 = data_path
|
||||
|
||||
subgraph_nodes = [matmul_node, onehot_node, reshape_node_0]
|
||||
if self.get_initializer(onehot_node.input[2]) is None:
|
||||
concat_node_0 = self.get_parent(onehot_node, 2)
|
||||
if concat_node_0 is None or concat_node_0.op_type != 'Concat':
|
||||
continue
|
||||
subgraph_nodes.append(concat_node_0)
|
||||
|
||||
shape_path = self.match_parent_path(
|
||||
node,
|
||||
['Cast', 'Concat', 'Unsqueeze', 'Cast', 'Squeeze', 'Slice', 'Cast', 'Shape', 'Gather'],
|
||||
|
|
@ -349,18 +483,13 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
outputs=node.output,
|
||||
name='segment_embedding_gather')
|
||||
|
||||
nodes_to_remove.extend([
|
||||
matmul_node,
|
||||
onehot_node,
|
||||
reshape_node_0,
|
||||
concat_node_0
|
||||
])
|
||||
nodes_to_remove.extend(subgraph_nodes)
|
||||
|
||||
nodes_to_remove.extend([cast_node_0, concat_node_1, unsqueeze_node, cast_node_1, squeeze_node, slice_node, cast_node_2, shape_node, node])
|
||||
self.add_node(gather_node)
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
print("Fused segment embedding" if len(nodes_to_remove) > 0 else "Failed to fuse segment embedding")
|
||||
logger.info("Fused segment embedding" if len(nodes_to_remove) > 0 else "Failed to fuse segment embedding")
|
||||
|
||||
def fuse_mask(self):
|
||||
nodes_to_remove = []
|
||||
|
|
@ -408,10 +537,10 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
self.add_node(unsqueeze_added_2)
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
print("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
|
||||
logger.info("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
|
||||
|
||||
def preprocess(self):
|
||||
self.remove_identity()
|
||||
self.fuse_word_embedding()
|
||||
self.fuse_segment_embedding()
|
||||
self.fuse_mask()
|
||||
self.fuse_mask()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import onnx
|
||||
import sys
|
||||
import argparse
|
||||
|
|
@ -10,6 +11,8 @@ import numpy as np
|
|||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, model, verbose):
|
||||
self.model = model
|
||||
|
|
@ -124,18 +127,57 @@ class OnnxModel:
|
|||
|
||||
return output_name_to_node[input]
|
||||
|
||||
def match_parent(self, node, parent_op_type, input_index=None, output_name_to_node=None, exclude=[]):
|
||||
def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]):
|
||||
'''
|
||||
Find parent node based on constraints on op_type.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_type (str): constraint of parent node op_type.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
||||
|
||||
Returns:
|
||||
parent: The matched parent node. None if not found.
|
||||
index: The input index of matched parent node. None if not found.
|
||||
'''
|
||||
for i, input in enumerate(node.input):
|
||||
if input in output_name_to_node:
|
||||
parent = output_name_to_node[input]
|
||||
if parent.op_type == parent_op_type and parent not in exclude:
|
||||
return parent, i
|
||||
|
||||
return None, None
|
||||
|
||||
def match_parent(self, node, parent_op_type, input_index=None, output_name_to_node=None, exclude=[], return_indice=None):
|
||||
'''
|
||||
Find parent node based on constraints on op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_type (str): constraint of parent node op_type.
|
||||
input_index (int or None): only check the parent given input index of current node.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
||||
return_indice (list): a list to append the input index when input_index is None.
|
||||
|
||||
Returns:
|
||||
parent: The matched parent node.
|
||||
'''
|
||||
assert node is not None
|
||||
assert input_index is None or input_index >= 0
|
||||
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
if input_index is None:
|
||||
parents = self.get_parents(node, output_name_to_node)
|
||||
for parent in parents:
|
||||
if parent.op_type == parent_op_type and parent not in exclude:
|
||||
return parent
|
||||
return None
|
||||
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
|
||||
if return_indice is not None:
|
||||
return_indice.append(index)
|
||||
return parent
|
||||
|
||||
if input_index < 0 or input_index >= len(node.input):
|
||||
if input_index >= len(node.input):
|
||||
return None
|
||||
|
||||
parent = self.get_parent(node, input_index, output_name_to_node)
|
||||
|
|
@ -144,7 +186,21 @@ class OnnxModel:
|
|||
|
||||
return None
|
||||
|
||||
def match_parent_path(self, node, parent_op_types, parent_input_index, output_name_to_node=None):
|
||||
def match_parent_path(self, node, parent_op_types, parent_input_index, output_name_to_node=None, return_indice=None):
|
||||
'''
|
||||
Find a sequence of input edges based on constraints on parent op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_types (str): constraint of parent node op_type of each input edge.
|
||||
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
return_indice (list): a list to append the input index when there is no constraint on input index of an edge.
|
||||
|
||||
Returns:
|
||||
parents: a list of matched parent node.
|
||||
'''
|
||||
assert(len(parent_input_index) == len(parent_op_types))
|
||||
|
||||
if output_name_to_node is None:
|
||||
|
|
@ -153,7 +209,7 @@ class OnnxModel:
|
|||
current_node = node
|
||||
matched_parents = []
|
||||
for i, op_type in enumerate(parent_op_types):
|
||||
matched_parent = self.match_parent(current_node, op_type, parent_input_index[i], output_name_to_node, exclude=[])
|
||||
matched_parent = self.match_parent(current_node, op_type, parent_input_index[i], output_name_to_node, exclude=[], return_indice=return_indice)
|
||||
if matched_parent is None:
|
||||
return None
|
||||
|
||||
|
|
@ -370,7 +426,7 @@ class OnnxModel:
|
|||
self.remove_nodes(unused_nodes)
|
||||
|
||||
if len(unused_nodes) > 0:
|
||||
print("Removed unused constant nodes:", len(unused_nodes))
|
||||
logger.info(f"Removed unused constant nodes: {len(unused_nodes)}")
|
||||
|
||||
def update_graph(self):
|
||||
graph = self.model.graph
|
||||
|
|
@ -382,7 +438,7 @@ class OnnxModel:
|
|||
if input_name not in remaining_input_names:
|
||||
remaining_input_names.append(input_name)
|
||||
if self.verbose:
|
||||
print("remaining input names", remaining_input_names)
|
||||
logger.info(f"remaining input names: {remaining_input_names}" )
|
||||
|
||||
# remove graph input that is not used
|
||||
inputs_to_remove = []
|
||||
|
|
@ -392,7 +448,8 @@ class OnnxModel:
|
|||
for input in inputs_to_remove:
|
||||
graph.input.remove(input)
|
||||
if self.verbose:
|
||||
print("remove unused input ", len(inputs_to_remove), [input.name for input in inputs_to_remove])
|
||||
names_to_remove = [input.name for input in inputs_to_remove]
|
||||
logger.info(f"remove {len(inputs_to_remove)} unused inputs: {names_to_remove}")
|
||||
|
||||
# remove weights that are not used
|
||||
weights_to_remove = []
|
||||
|
|
@ -406,21 +463,22 @@ class OnnxModel:
|
|||
graph.initializer.remove(initializer)
|
||||
|
||||
if self.verbose:
|
||||
print("remove unused initializers:", len(weights_to_remove), [initializer.name for initializer in weights_to_remove])
|
||||
print("remaining initializers:", weights_to_keep)
|
||||
names_to_remove = [initializer.name for initializer in weights_to_remove]
|
||||
logger.info(f"remove {len(weights_to_remove)} unused initializers: {names_to_remove}")
|
||||
logger.info(f"remaining initializers:{weights_to_keep}")
|
||||
|
||||
self.remove_unused_constant()
|
||||
|
||||
def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nodes, output_name_to_node):
|
||||
for node in nodes_to_remove:
|
||||
for output in node.output:
|
||||
if output in keep_outputs:
|
||||
for node_to_remove in nodes_to_remove:
|
||||
for output_to_remove in node_to_remove.output:
|
||||
if output_to_remove in keep_outputs:
|
||||
continue
|
||||
|
||||
if output in input_name_to_nodes:
|
||||
for node in input_name_to_nodes[output]:
|
||||
if node not in nodes_to_remove:
|
||||
if output_to_remove in input_name_to_nodes:
|
||||
for impacted_node in input_name_to_nodes[output_to_remove]:
|
||||
if impacted_node not in nodes_to_remove:
|
||||
if self.verbose:
|
||||
print("warning: it is not safe to remove nodes since output", output, "used by", node)
|
||||
logger.warning(f"it is not safe to remove nodes since output {output_to_remove} is used by {impacted_node}")
|
||||
return False
|
||||
return True
|
||||
|
|
@ -63,3 +63,5 @@ See below for description of all the options:
|
|||
By default, model uses float32 in computation. If this flag is specified, half-precision float will be used. This option is recommended for NVidia GPU with Tensor Core like V100 and T4. For older GPUs, float32 is likely faster.
|
||||
- **verbose**: (*optional*)
|
||||
Print verbose information when this flag is specified.
|
||||
- **run_onnxruntime**: (*optional*)
|
||||
Use onnxruntime to do optimization. This option is only avaiable for pytorch model right now.
|
||||
|
|
|
|||
|
|
@ -7,35 +7,65 @@
|
|||
# SkipLayerNormalization and EmbedLayerNormalization ops to optimize
|
||||
# performance on NVidia GPU and CPU.
|
||||
|
||||
# Note: This script is not required for Bert model exported from PyTorch.
|
||||
# OnnxRuntime has bert model optimization support internally. The recommended way is
|
||||
# to set optimization level to ORT_ENABLE_EXTENDED during Bert model inference.
|
||||
# See the following document for more information:
|
||||
# https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md
|
||||
|
||||
# For Bert model exported from PyTorch, OnnxRuntime has bert model optimization support internally.
|
||||
# You can use the option --use_onnxruntime to use model optimization from OnnxRuntime package.
|
||||
# For Bert model file like name.onnx, optimized model for GPU or CPU from OnnxRuntime will output as
|
||||
# name_ort_gpu.onnx or name_ort_cpu.onnx in the same directory.
|
||||
# This script is retained for experiment purpose. Useful senarios like the following:
|
||||
# (1) Change model from fp32 to fp16.
|
||||
# (2) Change input data type from int64 to int32.
|
||||
# (3) Model cannot be handled to OnnxRuntime graph optimization, and you can modify this script to get optimized model.
|
||||
|
||||
# (3) Some model cannot be handled by OnnxRuntime, and you can modify this script to get optimized model.
|
||||
|
||||
# This script has been tested using the following models:
|
||||
# (1) BertForSequenceClassification as in https://github.com/huggingface/transformers/blob/master/examples/run_glue.py
|
||||
# PyTorch 1.2 or above, and exported to Onnx using opset version 10 or 11.
|
||||
# (2) BertForQuestionAnswering as in https://github.com/huggingface/transformers/blob/master/examples/run_squad.py
|
||||
# PyTorch 1.2 or above, and exported to Onnx using opset version 10 or 11.
|
||||
|
||||
import logging
|
||||
import onnx
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
import onnxruntime
|
||||
from BertOnnxModel import BertOnnxModel
|
||||
from BertOnnxModelTF import BertOnnxModelTF
|
||||
|
||||
def main():
|
||||
logger = logging.getLogger('')
|
||||
|
||||
def run_onnxruntime(onnx_model_path, use_gpu, optimized_model_path=None):
|
||||
if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
||||
logger.error("There is no gpu for onnxruntime to do optimization.")
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
|
||||
|
||||
if optimized_model_path is None:
|
||||
path_prefix = onnx_model_path[:-5] #remove .onnx suffix
|
||||
optimized_model_path = "{}_ort_{}.onnx".format(path_prefix, "gpu" if use_gpu else "cpu")
|
||||
|
||||
sess_options.optimized_model_filepath = optimized_model_path
|
||||
|
||||
if not use_gpu:
|
||||
session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
session = onnxruntime.InferenceSession(onnx_model_path, sess_options)
|
||||
assert 'CUDAExecutionProvider' in session.get_providers() # Make sure there is GPU
|
||||
|
||||
assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path)
|
||||
logger.info("Save optimized model by onnxruntime to {}".format(optimized_model_path))
|
||||
return optimized_model_path
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True, type=str)
|
||||
parser.add_argument('--output', required=True, type=str)
|
||||
parser.add_argument('--framework', required=True, type=str, help="Original framework. Only support TensorFlow and PyTorch")
|
||||
parser.add_argument('--framework', required=True, type=str.lower, choices=["tensorflow", "pytorch"], help="Original framework")
|
||||
|
||||
# model parameters
|
||||
# model parameters.
|
||||
parser.add_argument('--num_heads', required=False, type=int, default=12, help="number of attention heads")
|
||||
parser.add_argument('--hidden_size', required=False, type=int, default=768)
|
||||
parser.add_argument('--sequence_length', required=False, type=int, default=128)
|
||||
|
|
@ -56,23 +86,60 @@ def main():
|
|||
parser.add_argument('--verbose', required=False, action='store_true')
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument('--use_onnxruntime', required=False, action='store_true')
|
||||
parser.set_defaults(use_onnxruntime=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
def optimize_model(input, framework, gpu_only, num_heads, hidden_size, sequence_length, input_int32, float16, verbose):
|
||||
model = ModelProto()
|
||||
with open(args.input, "rb") as f:
|
||||
with open(input, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
|
||||
if args.framework.lower() == 'tensorflow':
|
||||
bert_model = BertOnnxModelTF(model, args.num_heads, args.hidden_size, args.sequence_length, args.input_int32, args.float16, args.gpu_only, args.verbose)
|
||||
elif args.framework.lower() == 'pytorch':
|
||||
bert_model = BertOnnxModel(model, args.num_heads, args.hidden_size, args.sequence_length, args.input_int32, args.float16, args.gpu_only, args.verbose)
|
||||
else:
|
||||
print("Unsupported framework:" + args.framework)
|
||||
if framework == 'tensorflow':
|
||||
bert_model = BertOnnxModelTF(model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose)
|
||||
else: #framework == 'pytorch'
|
||||
bert_model = BertOnnxModel(model, num_heads, hidden_size, sequence_length, input_int32, float16, gpu_only, verbose)
|
||||
|
||||
bert_model.optimize()
|
||||
|
||||
return bert_model
|
||||
|
||||
def output_model(model, output):
|
||||
if output.endswith(".json"): # output to JSON. Only for test purpose.
|
||||
if isinstance(model, ModelProto):
|
||||
with open(output, "w") as out:
|
||||
out.write(str(model))
|
||||
logger.info("Output JSON to {}.json".format(output))
|
||||
else:
|
||||
with open(output, "wb") as out:
|
||||
out.write(model.SerializeToString())
|
||||
logger.info("Output final model to {}".format(output))
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
# output logging to stdout
|
||||
log_handler = logging.StreamHandler(sys.stdout)
|
||||
log_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
|
||||
log_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(log_handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
input_model_path = args.input
|
||||
if args.use_onnxruntime:
|
||||
if framework == 'tensorflow':
|
||||
logger.warning("onnxruntime does not have optimization for tensorflow model. Ignore the option --use_onnxruntime.")
|
||||
else:
|
||||
input_model_path = run_onnxruntime(input_model_path, args.gpu_only)
|
||||
|
||||
bert_model = optimize_model(input_model_path, args.framework, args.gpu_only, args.num_heads, args.hidden_size, args.sequence_length, args.input_int32, args.float16, args.verbose)
|
||||
|
||||
output_model(bert_model.model, args.output)
|
||||
|
||||
|
||||
with open(args.output, "wb") as out:
|
||||
out.write(bert_model.model.SerializeToString())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
183
onnxruntime/python/tools/bert/test_bert_optimization.py
Normal file
183
onnxruntime/python/tools/bert/test_bert_optimization.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from onnx import helper, TensorProto, ModelProto
|
||||
from onnx.helper import make_node, make_tensor_value_info
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from bert_model_optimization import optimize_model, run_onnxruntime
|
||||
from OnnxModel import OnnxModel
|
||||
|
||||
class TestBertOptimization(unittest.TestCase):
|
||||
def get_model(self, framework, index):
|
||||
if framework == "pytorch":
|
||||
return 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_{}.onnx'.format(index)
|
||||
else:
|
||||
return 'test_data\\bert_mrpc_tensorflow2.1_opset10\\TFBertForSequenceClassification_{}.onnx'.format(index)
|
||||
|
||||
def verify_node_count(self, bert_model, expected_node_count):
|
||||
for op_type, count in expected_node_count.items():
|
||||
if len(bert_model.get_nodes_by_op_type(op_type)) != count:
|
||||
print("{}:{} expected={}".format(op_type, len(bert_model.get_nodes_by_op_type(op_type)), count))
|
||||
self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count)
|
||||
|
||||
def test_pytorch_model_0_cpu_onnxruntime(self):
|
||||
input = self.get_model("pytorch", 0)
|
||||
output = 'temp.onnx'
|
||||
run_onnxruntime(input, use_gpu=False, optimized_model_path=output)
|
||||
model = ModelProto()
|
||||
with open(output, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
os.remove(output)
|
||||
bert_model = OnnxModel(model, verbose=False)
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'SkipLayerNormalization': 24,
|
||||
'Gelu': 0,
|
||||
'FastGelu': 0,
|
||||
'BiasGelu': 12
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_pytorch_model_0_gpu_onnxruntime(self):
|
||||
if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
||||
print("skip test_pytorch_model_0_gpu_onnxruntime since no gpu found")
|
||||
return
|
||||
|
||||
input = self.get_model("pytorch", 0)
|
||||
output = 'temp.onnx'
|
||||
run_onnxruntime(input, use_gpu=True, optimized_model_path=output)
|
||||
model = ModelProto()
|
||||
with open(output, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
os.remove(output)
|
||||
bert_model = OnnxModel(model, verbose=False)
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'SkipLayerNormalization': 24,
|
||||
'Gelu': 0,
|
||||
'FastGelu': 12,
|
||||
'BiasGelu': 0
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_pytorch_model_1_cpu_onnxruntime(self):
|
||||
input = self.get_model("pytorch", 1)
|
||||
output = 'temp.onnx'
|
||||
run_onnxruntime(input, use_gpu=False, optimized_model_path=output)
|
||||
model = ModelProto()
|
||||
with open(output, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
os.remove(output)
|
||||
bert_model = OnnxModel(model, verbose=False)
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'LayerNormalization': 24,
|
||||
'SkipLayerNormalization': 0,
|
||||
'Gelu': 0,
|
||||
'FastGelu': 0,
|
||||
'BiasGelu': 12
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_pytorch_model_1_gpu_onnxruntime(self):
|
||||
if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
||||
print("skip test_pytorch_model_1_gpu_onnxruntime since no gpu found")
|
||||
return
|
||||
|
||||
input = self.get_model("pytorch", 1)
|
||||
output = 'temp.onnx'
|
||||
run_onnxruntime(input, use_gpu=True, optimized_model_path=output)
|
||||
model = ModelProto()
|
||||
with open(output, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
os.remove(output)
|
||||
bert_model = OnnxModel(model, verbose=False)
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'LayerNormalization': 24,
|
||||
'SkipLayerNormalization': 0,
|
||||
'Gelu': 0,
|
||||
'FastGelu': 12,
|
||||
'BiasGelu': 0
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_pytorch_model_0_cpu(self):
|
||||
input = self.get_model("pytorch", 0)
|
||||
bert_model = optimize_model(input, framework='pytorch', gpu_only=False,
|
||||
num_heads=2, hidden_size=8, sequence_length=10,
|
||||
input_int32=False, float16=False, verbose=False)
|
||||
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'SkipLayerNormalization': 24,
|
||||
'Gelu': 12,
|
||||
'FastGelu': 0,
|
||||
'BiasGelu': 0
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_pytorch_model_0_gpu(self):
|
||||
if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
||||
print("skip test_pytorch_model_0_gpu since no gpu found")
|
||||
return
|
||||
|
||||
input = self.get_model("pytorch", 0)
|
||||
bert_model = optimize_model(input, framework='pytorch', gpu_only=True,
|
||||
num_heads=2, hidden_size=8, sequence_length=10,
|
||||
input_int32=False, float16=False, verbose=False)
|
||||
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 1,
|
||||
'Attention': 12,
|
||||
'SkipLayerNormalization': 24,
|
||||
'FastGelu': 12,
|
||||
'Gelu': 0,
|
||||
'BiasGelu': 0
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
def test_tensorflow_model_1_cpu(self):
|
||||
input = self.get_model("tensorflow", 1)
|
||||
|
||||
# The model need constant folding. Use onnxruntime to do so for now.
|
||||
temp = 'temp.onnx'
|
||||
run_onnxruntime(input, use_gpu=False, optimized_model_path=temp)
|
||||
|
||||
bert_model = optimize_model(temp, framework='tensorflow', gpu_only=False,
|
||||
num_heads=2, hidden_size=8, sequence_length=7,
|
||||
input_int32=False, float16=False, verbose=False)
|
||||
os.remove(temp)
|
||||
|
||||
# Optimization for tensorflow model is still on-going.
|
||||
# TODO: update this after code complete.
|
||||
expected_node_count = {
|
||||
'EmbedLayerNormalization': 0,
|
||||
'Attention': 0,
|
||||
'LayerNormalization': 0,
|
||||
'SkipLayerNormalization': 25,
|
||||
'BiasGelu': 0,
|
||||
'Gelu': 12,
|
||||
'FastGelu': 0
|
||||
}
|
||||
self.verify_node_count(bert_model, expected_node_count)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1 @@
|
|||
Boutput_1Jw‘ą‚ÔZ»
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
BstartJ(Tţ˝đ·8˝â*0˝C˝<43>s+˝ĎlŘĽć«*˝Dű*˝<>÷řĽ&ü)˝
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
BendJ(€'±<Œ <9ù <î½<@|=ÌAC=9lå<*¨<5Ô<35><]ñ;
|
||||
Loading…
Reference in a new issue