Script for converting BERT model for performance optimization (#2037)

* script for converting ONNX model for BERT performance optimization

* Remove code that not needed anymore.

* refine the script

* Support BERT model exported from PyTorch 1.3
Keep opset version
Exact match in Attention, Layer normalziation fusions.

* read batch_size from input model directly
This commit is contained in:
Tianlei Wu 2019-10-15 16:50:08 -07:00 committed by GitHub
parent ec136ac60f
commit b0f8ec7a7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 947 additions and 0 deletions

View file

@ -0,0 +1,55 @@
# BERT Model Optimization Tool Overview
This tool converts a BERT ONNX model exported from PyTorch, and generates a optimized model to run faster in NVidia GPU.
Currently, this script **cannot** process BERT models exported from Tensorflow since the graph has some difference.
## Export an BERT model from PyTorch
For example, after using https://github.com/huggingface/transformers to Train a BERT model in PyTorch 1.3, you can use the following function to export ONNX model.
Please specify do_constant_folding=True. That's required for this tool.
```python
def export_onnx(args, model, output_path):
model.eval() # set the model to inference mode
device = torch.device("cpu")
model.to(device)
dummy_input0 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(1).to(device)
dummy_input1 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(1).to(device)
dummy_input2 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(0).to(device)
dummy_input = (dummy_input0, dummy_input1, dummy_input2)
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
output_path, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ["input_ids", "input_mask", "segment_ids"],
output_names = ["output"],
dynamic_axes={'input_ids' : {0 : 'batch_size'}, # variable lenght axes
'input_mask' : {0 : 'batch_size'},
'segment_ids' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
```
## Model Optimization
Example of using the script bert_model_optimization.py to convert a BERT-large model to run in V100 GPU:
```console
python bert_model_optimization.py --input input_model.onnx --output optimized_model.onnx --num_heads 24 --hidden_size 1024 --sequence_length 128 --input_int32 --float16
```
## Options
See below for description of all the options:
- **input**: input model path
- **output**: output model path
- **num_heads**: (*default: 12*)
Number of attention heads, like 24 for BERT-large model.
- **hidden_size**: (*default: 768*)
- **sequence_length**: (*default: 128*)
Maximum sequence length.
- **input_int32**: (*optional*)
Exported model ususally uses int64 tensor as input. If this flag is specified, int32 tensors will be used as input, and it could avoid un-necessary Cast nodes and get better performance.
- **float16**: (*optional*)
By default, model uses float32 in computation. If this flag is specified, half-precision float will be used. This option is recommended for NVidia GPU with Tensor Core like V100 and T4. For older GPUs, float32 is likely faster.

View file

@ -0,0 +1,892 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
# Convert Bert ONNX model exported from PyTorch to use Attention, Gelu, SkipLayerNormalization and
# EmbedLayerNormalization ops to optimize performance on NVidia GPU.
import onnx
import sys
import argparse
import numpy as np
from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
class OnnxModel:
def __init__(self, model):
self.model = model
self.node_name_counter = {}
def input_name_to_nodes(self):
input_name_to_nodes = {}
for node in self.model.graph.node:
for input_name in node.input:
if input_name not in input_name_to_nodes:
input_name_to_nodes[input_name] = [node]
else:
input_name_to_nodes[input_name].append(node)
return input_name_to_nodes
def output_name_to_node(self):
output_name_to_node = {}
for node in self.model.graph.node:
for output_name in node.output:
output_name_to_node[output_name] = node
return output_name_to_node
def nodes(self):
return self.model.graph.node
def graph(self):
return self.model.graph
def remove_node(self, node):
if node in self.model.graph.node:
self.model.graph.node.remove(node)
def remove_nodes(self, nodes_to_remove):
for node in nodes_to_remove:
self.remove_node(node)
def add_node(self, node):
self.model.graph.node.extend([node])
def add_nodes(self, nodes_to_add):
self.model.graph.node.extend(nodes_to_add)
def add_initializer(self, tensor):
self.model.graph.initializer.extend([tensor])
def add_input(self, input):
self.model.graph.input.extend([input])
@staticmethod
def replace_node_input(node, old_input_name, new_input_name):
assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
for j in range(len(node.input)):
if node.input[j] == old_input_name:
node.input[j] = new_input_name
def replace_input_of_all_nodes(self, old_input_name, new_input_name):
for node in self.model.graph.node:
OnnxModel.replace_node_input(node, old_input_name, new_input_name)
def get_initializer(self,name):
for tensor in self.model.graph.initializer:
if tensor.name == name:
return tensor
return None
def get_nodes_by_op_type(self, op_type):
return [n for n in self.model.graph.node if n.op_type == op_type]
def get_children(self, node, input_name_to_nodes=None):
if (input_name_to_nodes is None):
input_name_to_nodes = self.input_name_to_nodes()
children = []
for output in node.output:
if output in input_name_to_nodes:
for node in input_name_to_nodes[output]:
children.append(node)
return children
def get_parents(self, node, output_name_to_node=None):
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
parents = []
for input in node.input:
if input in output_name_to_node:
parents.append(output_name_to_node[input])
return parents
def get_parent(self, node, i, output_name_to_node=None):
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
if len(node.input) <= i:
return None
input = node.input[i]
if input not in output_name_to_node:
return None
return output_name_to_node[input]
def match_parent_path(self, node, parent_op_types, parent_input_index=None, output_name_to_node=None):
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
if parent_input_index is None:
parent_input_index = [0] * len(parent_op_types)
assert(len(parent_input_index) == len(parent_op_types))
current_node = node
matched_parents = []
for i, op_type in enumerate(parent_op_types):
input_index = parent_input_index[i]
if input_index >= len(current_node.input):
return None
parent = self.get_parent(current_node, input_index, output_name_to_node)
if parent is None:
return None
if parent.op_type == parent_op_types[i]:
matched_parents.append(parent)
current_node = parent
return matched_parents
def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, recursive=True):
children = self.get_children(node, input_name_to_nodes)
dq = deque(children)
while len(dq) > 0:
current_node = dq.pop()
if current_node.op_type == child_type:
return current_node
if recursive:
children = self.get_children(current_node, input_name_to_nodes)
for child in children:
dq.appendleft(child)
return None
def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True):
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
parents = self.get_parents(node, output_name_to_node)
dq = deque(parents)
while len(dq) > 0:
current_node = dq.pop()
if current_node.op_type == parent_type:
return current_node
if recursive:
parents = self.get_parents(current_node, output_name_to_node)
for parent in parents:
dq.appendleft(parent)
return None
def get_constant_value(self, output_name):
for node in self.get_nodes_by_op_type('Constant'):
if node.output[0] == output_name:
for att in node.attribute:
if att.name == 'value':
return numpy_helper.to_array(att.t)
def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes=None):
if input_name_to_nodes is None:
input_name_to_nodes = self.input_name_to_nodes()
children = input_name_to_nodes[root_node.output[0]]
unique_nodes = []
dq = deque(children)
while len(dq) > 0:
current_node = dq.pop()
if current_node in stop_nodes:
continue
if current_node not in unique_nodes:
unique_nodes.append(current_node)
for output in current_node.output:
if output in input_name_to_nodes:
children = input_name_to_nodes[output]
for child in children:
dq.appendleft(child)
return unique_nodes
def convert_model_float32_to_float16(self):
graph = self.model.graph
initializers = graph.initializer
for input_value_info in graph.input:
if input_value_info.type.tensor_type.elem_type == 1:
input_value_info.type.tensor_type.elem_type = 10
for output_value_info in graph.output:
if output_value_info.type.tensor_type.elem_type == 1:
output_value_info.type.tensor_type.elem_type = 10
for initializer in initializers:
if initializer.data_type == 1:
initializer.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(initializer).astype(np.float16), initializer.name))
for node in graph.node:
if node.op_type == 'Constant':
for att in node.attribute:
if att.name == 'value' and att.t.data_type == 1:
att.CopyFrom(onnx.helper.make_attribute("value", numpy_helper.from_array(numpy_helper.to_array(att.t).astype(np.float16))))
if node.op_type == 'Cast':
for att in node.attribute:
if att.name == 'to' and att.i == 1:
att.CopyFrom(onnx.helper.make_attribute("to", 10))
# create a new name for node
def create_node_name(self, op_type, name_prefix=None):
if op_type in self.node_name_counter:
self.node_name_counter[op_type] += 1
else:
self.node_name_counter[op_type] = 1
if name_prefix is not None:
full_name = name_prefix + str(self.node_name_counter[op_type])
else:
full_name = op_type + "_" + str(self.node_name_counter[op_type])
# Check whether the name is taken:
nodes = self.get_nodes_by_op_type(op_type)
for node in nodes:
if node.name == full_name:
raise Exception("Node name already taken:", full_name)
return full_name
def find_graph_input(self, input_name):
for input in self.model.graph.input:
if input.name == input_name:
return input
return None
def get_parent_subgraph_nodes(self, node, stop_nodes, output_name_to_node=None):
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
unique_nodes = []
parents = self.get_parents(node, output_name_to_node)
dq = deque(parents)
while len(dq) > 0:
current_node = dq.pop()
if current_node in stop_nodes:
continue
if current_node not in unique_nodes:
unique_nodes.append(current_node)
for input in current_node.input:
if input in output_name_to_node:
dq.appendleft(output_name_to_node[input])
return unique_nodes
@staticmethod
def input_index(node_output, child_node):
index = 0
for input in child_node.input:
if input == node_output:
return index
index += 1
return -1
def remove_unused_constant(self):
input_name_to_nodes = self.input_name_to_nodes()
#remove unused constant
unused_nodes = []
nodes = self.nodes()
for node in nodes:
if node.op_type == "Constant" and node.output[0] not in input_name_to_nodes:
unused_nodes.append(node)
self.remove_nodes(unused_nodes)
if len(unused_nodes) > 0:
print("Removed unused constant nodes:", len(unused_nodes))
def update_graph(self, verbose=False):
graph = self.model.graph
remaining_input_names = []
for node in graph.node:
if node.op_type != "Constant":
for input_name in node.input:
if input_name not in remaining_input_names:
remaining_input_names.append(input_name)
if verbose:
print("remaining input names", remaining_input_names)
# remove graph input that is not used
inputs_to_remove = []
for input in graph.input:
if input.name not in remaining_input_names:
inputs_to_remove.append(input)
for input in inputs_to_remove:
graph.input.remove(input)
if verbose:
print("remove unused input ", len(inputs_to_remove), [input.name for input in inputs_to_remove])
# remove weights that are not used
weights_to_remove = []
for initializer in graph.initializer:
if initializer.name not in remaining_input_names:
weights_to_remove.append(initializer)
for initializer in weights_to_remove:
graph.initializer.remove(initializer)
if verbose:
print("remove unused initializers:", len(weights_to_remove), [initializer.name for initializer in weights_to_remove])
self.remove_unused_constant()
class BertOnnxModel(OnnxModel):
def __init__(self, model, num_heads, hidden_size, sequence_length):
assert num_heads > 0
assert hidden_size % num_heads == 0
assert sequence_length > 0
super(BertOnnxModel, self).__init__(model)
self.num_heads = num_heads
self.sequence_length = sequence_length
self.hidden_size = hidden_size
self.mask_input = None
self.embed_node = None
# constant node names
self.normalize_name = "SkipLayerNormalization"
self.gelu_name = 'Gelu'
self.attention_name = 'Attention'
def get_normalize_nodes(self):
return self.get_nodes_by_op_type(self.normalize_name)
def normalize_children_types(self):
return ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']
def set_mask_input(self, input):
if self.mask_input is not None and input != self.mask_input:
raise Exception("Different mask inputs", self.mask_input, input)
self.mask_input = input
def create_attention_node(self, q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, input, output):
q_weight = self.get_initializer(q_matmul.input[1])
k_weight = self.get_initializer(k_matmul.input[1])
v_weight = self.get_initializer(v_matmul.input[1])
q_bias = self.get_initializer(q_add.input[1])
k_bias = self.get_initializer(k_add.input[1])
v_bias = self.get_initializer(v_add.input[1])
qw = numpy_helper.to_array(q_weight)
assert qw.shape == (self.hidden_size, self.hidden_size)
kw = numpy_helper.to_array(k_weight)
assert kw.shape == (self.hidden_size, self.hidden_size)
vw = numpy_helper.to_array(v_weight)
assert vw.shape == (self.hidden_size, self.hidden_size)
qkv_weight = np.stack((qw, kw, vw), axis=-2)
qb = numpy_helper.to_array(q_bias)
assert qb.shape == (self.hidden_size,)
kb = numpy_helper.to_array(k_bias)
assert kb.shape == (self.hidden_size,)
vb = numpy_helper.to_array(v_bias)
assert vb.shape == (self.hidden_size,)
qkv_bias = np.stack((qb, kb, vb), axis=-2)
attention_node_name = self.create_node_name(self.attention_name)
weight = onnx.helper.make_tensor(name=attention_node_name + '_qkv_weight',
data_type=TensorProto.FLOAT,
dims=[self.hidden_size, 3 * self.hidden_size],
vals=qkv_weight.flatten().tolist())
self.add_initializer(weight)
weight_input = onnx.helper.make_tensor_value_info(weight.name, TensorProto.FLOAT, [self.hidden_size, 3 * self.hidden_size])
self.add_input(weight_input)
bias = onnx.helper.make_tensor(name=attention_node_name + '_qkv_bias',
data_type=TensorProto.FLOAT,
dims=[3 * self.hidden_size],
vals=qkv_bias.flatten().tolist())
self.add_initializer(bias)
bias_input = onnx.helper.make_tensor_value_info(bias.name, TensorProto.FLOAT, [3 * self.hidden_size])
self.add_input(bias_input)
attention_node = onnx.helper.make_node(self.attention_name,
inputs=[input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias', self.mask_input],
outputs=[output],
name=attention_node_name)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([onnx.helper.make_attribute("num_heads", self.num_heads)])
self.add_node(attention_node)
def fuse_attention(self, verbose=False):
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
nodes_to_remove = []
for normalize_node in self.get_normalize_nodes():
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = None
root_input = None
for i, input in enumerate(normalize_node.input):
if input not in output_name_to_node:
continue
children = input_name_to_nodes[input]
children_types = sorted([child.op_type for child in children])
if children_types != self.normalize_children_types():
qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], [i, 0, 0, 0, 0])
else:
root_input = input
if root_input is None or qkv_nodes is None:
continue
(add_qkv, matmul_qkv, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
if v_nodes is None:
continue
(transpose_v, reshape_v, add_v, matmul_v) = v_nodes
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Div', 'MatMul'], [0, 0, 0, 0])
if qk_nodes is None:
continue
(softmax_qk, add_qk, div_qk, matmul_qk) = qk_nodes
q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
if q_nodes is None:
continue
(transpose_q, reshape_q, add_q, matmul_q) = q_nodes
k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
if k_nodes is None:
continue
(transpose_k, reshape_k, add_k, matmul_k) = k_nodes
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [1, 0, 1, 0, 0])
if mask_nodes is None:
continue
(mul_mask, sub_mask, cast_mask, unsqueeze_mask, unsqueeze_mask_0) = mask_nodes
if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input:
self.set_mask_input(unsqueeze_mask_0.input[0])
self.create_attention_node(matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, root_input, reshape_qkv.output[0])
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
nodes_to_remove.extend(qk_nodes)
nodes_to_remove.extend(q_nodes)
nodes_to_remove.extend(k_nodes)
nodes_to_remove.extend(v_nodes)
nodes_to_remove.extend(mask_nodes)
self.remove_nodes(nodes_to_remove)
self.update_graph(verbose)
def fuse_gelu(self):
nodes = self.nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
nodes_to_remove = []
nodes_to_add = []
for node in self.get_normalize_nodes():
children = input_name_to_nodes[node.output[0]]
if len(children) != 2:
continue
children_types = sorted([child.op_type for child in children])
if children_types != ['MatMul', 'SkipLayerNormalization']:
continue
matmul_node = self.find_first_child_by_type(node, 'MatMul', input_name_to_nodes)
matmul_child = input_name_to_nodes[matmul_node.output[0]]
if len(matmul_child) != 1 or matmul_child[0].op_type != 'Add':
continue
add_node = matmul_child[0]
children = input_name_to_nodes[add_node.output[0]]
children_types = sorted([child.op_type for child in children])
if children_types != ['Div', 'Mul']:
continue
matmul_2 = self.find_first_child_by_type(add_node, 'MatMul', input_name_to_nodes)
if matmul_2 is None:
continue
subgraph_nodes = self.get_children_subgraph_nodes(add_node, [matmul_2], input_name_to_nodes)
if len(subgraph_nodes) != 5:
continue
nodes_to_remove.extend(subgraph_nodes)
gelu_node = onnx.helper.make_node(self.gelu_name,
inputs=[add_node.output[0]],
outputs=[matmul_2.input[0]])
gelu_node.domain = "com.microsoft"
nodes_to_add.append(gelu_node)
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
def fuse_reshape(self):
nodes = self.nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
nodes_to_remove = []
nodes_to_add = []
for reshape_node in self.get_nodes_by_op_type('Reshape'):
concat_node = output_name_to_node[reshape_node.input[1]]
if concat_node.op_type != 'Concat' or len(concat_node.input) < 3:
continue
path = self.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0], output_name_to_node)
if path is None:
continue
(unsqueeze_0, gather_0, shape_0) = path
path = self.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0], output_name_to_node)
if path is None:
continue
(unsqueeze_1, gather_1, shape_1) = path
shape = []
gather_value = self.get_constant_value(gather_0.input[1])
if gather_value == 0:
shape.append(0)
gather_value = self.get_constant_value(gather_1.input[1])
if gather_value == 1:
shape.append(0)
if len(shape) != 2:
continue
if (len(concat_node.input) > 2):
concat_2 = self.get_initializer(concat_node.input[2])
if concat_2 is None:
continue
shape.extend(numpy_helper.to_array(concat_2))
if (len(concat_node.input) > 3):
concat_3 = self.get_initializer(concat_node.input[3])
if concat_3 is None:
continue
shape.extend(numpy_helper.to_array(concat_3))
shape_value = np.asarray(shape, dtype=np.int64)
constant_shape_name = self.create_node_name('Constant', 'constant_shape')
new_node = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[constant_shape_name],
value=onnx.helper.make_tensor(
name='const_tensor',
data_type=TensorProto.INT64,
dims=shape_value.shape,
vals=shape_value))
reshape_node.input[1] = constant_shape_name
nodes_to_remove.extend([concat_node, unsqueeze_0, unsqueeze_1, gather_0, gather_1, shape_0, shape_1])
nodes_to_add.append(new_node)
print("Fused reshape count:", len(nodes_to_add))
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
"""
Embed Layer Normalization will fuse embeddings and mask processing into one node.
The embeddings before conversion:
(input_ids) --------> Gather ----------+ (segment_ids)
| | |
| v v
+--> Shape --> Expand -> Gather---->Add Gather
| ^ | |
| | v v
+---(optional graph) SkipLayerNormalization
Optional graph is used to generate position list (0, 1, ...). It can be a constant in some model.
"""
def fuse_embed_layer(self, verbose=False):
if self.mask_input is None:
print("skip embed layer fusion since mask input is not found")
return
nodes = self.nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
mask_input_name = self.mask_input
nodes_to_remove = []
nodes_to_add = []
# Find the first normalize node could be embedding layer.
normalize_node = None
for node in self.get_normalize_nodes():
if self.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is not None:
if self.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is not None:
normalize_node = node
break
if normalize_node is None:
print("did not find embedding layer")
# Here we assume the order of embedding is word_embedding + position_embedding + segment_embedding.
word_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
if word_embedding_path is None:
print("Failed to find word embedding")
return
add_node, word_embedding_gather = word_embedding_path
position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
if position_embedding_path is None:
print("Failed to find position embedding")
return
position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path
segment_embedding_path = self.match_parent_path(normalize_node, ['Gather'], [1])
if segment_embedding_path is None:
print("failed to find segment embedding")
return
segment_embedding_gather = segment_embedding_path[0]
input_ids = word_embedding_gather.input[1]
segment_ids = segment_embedding_gather.input[1]
if position_embedding_shape.input[0] != input_ids:
print("position and word embedding is expected to be applied on same input")
return
subgraph_nodes = self.get_parent_subgraph_nodes(position_embedding_expand, [input_ids], output_name_to_node)
nodes_to_remove.extend(subgraph_nodes)
nodes_to_remove.extend([normalize_node, add_node, segment_embedding_gather, word_embedding_gather, position_embedding_gather, position_embedding_expand])
embed_node = onnx.helper.make_node('EmbedLayerNormalization',
inputs=[input_ids, segment_ids, mask_input_name,
word_embedding_gather.input[0], position_embedding_gather.input[0], segment_embedding_gather.input[0],
normalize_node.input[2], normalize_node.input[3]], # gamma and beta
outputs=["embed_output", "mask_idx"],
name="EmbedLayer")
embed_node.domain = "com.microsoft"
# store embed node for other processing
self.embed_node = embed_node
nodes_to_add.extend([embed_node])
self.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')
self.replace_input_of_all_nodes(mask_input_name, 'mask_idx')
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
self.update_graph(verbose)
def get_batch_size_from_graph_input(self):
graph = self.graph()
for input in graph.input:
if input.name in self.embed_node.input[:3]:
tensor_type = input.type.tensor_type
if (tensor_type.HasField("shape")):
for d in tensor_type.shape.dim:
if (d.HasField("dim_value")):
return d.dim_value
elif (d.HasField("dim_param")):
return str(d.dim_param) # unknown dimension with symbolic name
return None
return None
def change_input_to_int32(self):
original_opset_version = self.model.opset_import[0].version
graph = self.graph()
batch_size = self.get_batch_size_from_graph_input()
input_batch_size = batch_size if isinstance(batch_size, int) else 1
new_graph_inputs = []
for input in graph.input:
if input.name in self.embed_node.input[:3]: # Only the first 3 inputs of embed node need int32 conversion.
int32_input = onnx.helper.make_tensor_value_info(input.name, TensorProto.INT32, [input_batch_size, self.sequence_length])
new_graph_inputs.append(int32_input)
else:
new_graph_inputs.append(input)
graph_def = onnx.helper.make_graph(graph.node,
'int32 inputs',
new_graph_inputs,
graph.output,
initializer=graph.initializer,
value_info=graph.value_info)
self.model = onnx.helper.make_model(graph_def, producer_name='bert model optimizer')
if isinstance(batch_size, str):
self.update_dynamic_batch_io(batch_size)
# restore opset version
self.model.opset_import[0].version = original_opset_version
def cast_input_to_int32(self):
for input in self.embed_node.input[:3]:
graph_input = self.find_graph_input(input)
if graph_input is not None and graph_input.type.tensor_type.elem_type == TensorProto.INT64:
cast_output = input + '_int32'
cast_node = onnx.helper.make_node('Cast', inputs=[input], outputs=[cast_output])
cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))])
self.replace_input_of_all_nodes(input, cast_output)
self.add_node(cast_node)
# Update input and output using dynamic batch
def update_dynamic_batch_io(self, dynamic_batch_dim='batch'):
dynamic_batch_inputs = {}
for input in self.model.graph.input:
for embed_input in self.embed_node.input[:3]:
if embed_input == input.name:
dim_proto = input.type.tensor_type.shape.dim[0]
dim_proto.dim_param = dynamic_batch_dim
for output in self.model.graph.output:
dim_proto = output.type.tensor_type.shape.dim[0]
dim_proto.dim_param = dynamic_batch_dim
"""
Layer Normalization will fuse Add + LayerNormalization into one node:
+----------------------+
| |
| v
Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
| ^
| |
+-----------------------------------------------+
It also handles cases of duplicated sub nodes exported from older version of PyTorch:
+----------------------+
| v
| +-------> Sub-----------------------------------------------+
| | |
| | v
Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
| ^
| |
+----------------------+
"""
def fuse_layer_norm(self):
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
nodes_to_remove = []
nodes_to_add = []
for node in self.nodes():
if node.op_type == 'Add':
children = self.get_children(node, input_name_to_nodes)
children_types = sorted([child.op_type for child in children])
if children_types != ["ReduceMean", "Sub"] and children_types != ["ReduceMean", "Sub", "Sub"]:
continue
div_node = None
for child in children:
if child.op_type == 'Sub':
div_node = self.find_first_child_by_type(child, 'Div', input_name_to_nodes, recursive=False)
if div_node is not None:
break
if div_node is None:
continue
parent_nodes = self.match_parent_path(div_node, ['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Sub', 'Add'], [1, 0, 0, 0, 0, 0], output_name_to_node)
if parent_nodes is None:
continue
sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node, first_add_node = parent_nodes
if first_add_node != node:
continue
mul_node = input_name_to_nodes[div_node.output[0]][0]
if mul_node.op_type != 'Mul':
continue
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != 'Add':
continue
nodes_to_remove.append(node)
nodes_to_remove.extend(children)
nodes_to_remove.extend([last_add_node, mul_node, div_node, sqrt_node, second_add_node, reduce_mean_node, pow_node])
normalize_node_name = self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm")
inputs = [i for i in node.input]
inputs.extend([mul_node.input[0], last_add_node.input[1]])
normalize_node = onnx.helper.make_node(self.normalize_name,
inputs=inputs,
outputs=[last_add_node.output[0]],
name=normalize_node_name)
normalize_node.domain = "com.microsoft"
nodes_to_add.extend([normalize_node])
self.remove_nodes(nodes_to_remove)
self.add_nodes(nodes_to_add)
print("Fused layer normalization count:", len(nodes_to_add))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, type=str)
parser.add_argument('--output', required=True, type=str)
# model parameters
parser.add_argument('--num_heads', required=False, type=int, default=12, help="number of attention heads")
parser.add_argument('--hidden_size', required=False, type=int, default=768)
parser.add_argument('--sequence_length', required=False, type=int, default=128)
# Use int32 (instead of int64) tensor as input to avoid unnecessary data type cast.
parser.add_argument('--input_int32', required=False, action='store_true')
parser.set_defaults(input_int32=False)
# For NVidia GPU with Tensor Core like V100 and T4, half-precision float brings better performance.
parser.add_argument('--float16', required=False, action='store_true')
parser.set_defaults(float16=False)
parser.add_argument('--verbose', required=False, action='store_true')
parser.set_defaults(verbose=False)
args = parser.parse_args()
model = ModelProto()
with open(args.input, "rb") as f:
model.ParseFromString(f.read())
bert_model = BertOnnxModel(model, args.num_heads, args.hidden_size, args.sequence_length)
bert_model.fuse_layer_norm()
bert_model.fuse_gelu()
bert_model.fuse_reshape()
bert_model.fuse_attention(args.verbose)
bert_model.fuse_embed_layer(args.verbose)
if bert_model.embed_node is None:
print("Failed to fuse embedding layer.")
return
if args.input_int32:
bert_model.change_input_to_int32()
else:
bert_model.cast_input_to_int32()
if args.float16:
bert_model.convert_model_float32_to_float16()
with open(args.output, "wb") as out:
out.write(bert_model.model.SerializeToString())
main()