mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Script for converting BERT model for performance optimization (#2037)
* script for converting ONNX model for BERT performance optimization * Remove code that not needed anymore. * refine the script * Support BERT model exported from PyTorch 1.3 Keep opset version Exact match in Attention, Layer normalziation fusions. * read batch_size from input model directly
This commit is contained in:
parent
ec136ac60f
commit
b0f8ec7a7d
2 changed files with 947 additions and 0 deletions
55
onnxruntime/python/tools/bert/README.md
Normal file
55
onnxruntime/python/tools/bert/README.md
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# BERT Model Optimization Tool Overview
|
||||
|
||||
This tool converts a BERT ONNX model exported from PyTorch, and generates a optimized model to run faster in NVidia GPU.
|
||||
|
||||
Currently, this script **cannot** process BERT models exported from Tensorflow since the graph has some difference.
|
||||
|
||||
## Export an BERT model from PyTorch
|
||||
For example, after using https://github.com/huggingface/transformers to Train a BERT model in PyTorch 1.3, you can use the following function to export ONNX model.
|
||||
|
||||
Please specify do_constant_folding=True. That's required for this tool.
|
||||
|
||||
```python
|
||||
def export_onnx(args, model, output_path):
|
||||
model.eval() # set the model to inference mode
|
||||
device = torch.device("cpu")
|
||||
model.to(device)
|
||||
dummy_input0 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(1).to(device)
|
||||
dummy_input1 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(1).to(device)
|
||||
dummy_input2 = torch.LongTensor(args.eval_batch_size, args.max_seq_length).fill_(0).to(device)
|
||||
dummy_input = (dummy_input0, dummy_input1, dummy_input2)
|
||||
torch.onnx.export(model, # model being run
|
||||
dummy_input, # model input (or a tuple for multiple inputs)
|
||||
output_path, # where to save the model (can be a file or file-like object)
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=10, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names = ["input_ids", "input_mask", "segment_ids"],
|
||||
output_names = ["output"],
|
||||
dynamic_axes={'input_ids' : {0 : 'batch_size'}, # variable lenght axes
|
||||
'input_mask' : {0 : 'batch_size'},
|
||||
'segment_ids' : {0 : 'batch_size'},
|
||||
'output' : {0 : 'batch_size'}})
|
||||
```
|
||||
## Model Optimization
|
||||
|
||||
Example of using the script bert_model_optimization.py to convert a BERT-large model to run in V100 GPU:
|
||||
```console
|
||||
python bert_model_optimization.py --input input_model.onnx --output optimized_model.onnx --num_heads 24 --hidden_size 1024 --sequence_length 128 --input_int32 --float16
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
See below for description of all the options:
|
||||
|
||||
- **input**: input model path
|
||||
- **output**: output model path
|
||||
- **num_heads**: (*default: 12*)
|
||||
Number of attention heads, like 24 for BERT-large model.
|
||||
- **hidden_size**: (*default: 768*)
|
||||
- **sequence_length**: (*default: 128*)
|
||||
Maximum sequence length.
|
||||
- **input_int32**: (*optional*)
|
||||
Exported model ususally uses int64 tensor as input. If this flag is specified, int32 tensors will be used as input, and it could avoid un-necessary Cast nodes and get better performance.
|
||||
- **float16**: (*optional*)
|
||||
By default, model uses float32 in computation. If this flag is specified, half-precision float will be used. This option is recommended for NVidia GPU with Tensor Core like V100 and T4. For older GPUs, float32 is likely faster.
|
||||
892
onnxruntime/python/tools/bert/bert_model_optimization.py
Normal file
892
onnxruntime/python/tools/bert/bert_model_optimization.py
Normal file
|
|
@ -0,0 +1,892 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
# Convert Bert ONNX model exported from PyTorch to use Attention, Gelu, SkipLayerNormalization and
|
||||
# EmbedLayerNormalization ops to optimize performance on NVidia GPU.
|
||||
|
||||
import onnx
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.node_name_counter = {}
|
||||
|
||||
def input_name_to_nodes(self):
|
||||
input_name_to_nodes = {}
|
||||
for node in self.model.graph.node:
|
||||
for input_name in node.input:
|
||||
if input_name not in input_name_to_nodes:
|
||||
input_name_to_nodes[input_name] = [node]
|
||||
else:
|
||||
input_name_to_nodes[input_name].append(node)
|
||||
return input_name_to_nodes
|
||||
|
||||
def output_name_to_node(self):
|
||||
output_name_to_node = {}
|
||||
for node in self.model.graph.node:
|
||||
for output_name in node.output:
|
||||
output_name_to_node[output_name] = node
|
||||
return output_name_to_node
|
||||
|
||||
def nodes(self):
|
||||
return self.model.graph.node
|
||||
|
||||
def graph(self):
|
||||
return self.model.graph
|
||||
|
||||
def remove_node(self, node):
|
||||
if node in self.model.graph.node:
|
||||
self.model.graph.node.remove(node)
|
||||
|
||||
def remove_nodes(self, nodes_to_remove):
|
||||
for node in nodes_to_remove:
|
||||
self.remove_node(node)
|
||||
|
||||
def add_node(self, node):
|
||||
self.model.graph.node.extend([node])
|
||||
|
||||
def add_nodes(self, nodes_to_add):
|
||||
self.model.graph.node.extend(nodes_to_add)
|
||||
|
||||
def add_initializer(self, tensor):
|
||||
self.model.graph.initializer.extend([tensor])
|
||||
|
||||
def add_input(self, input):
|
||||
self.model.graph.input.extend([input])
|
||||
|
||||
@staticmethod
|
||||
def replace_node_input(node, old_input_name, new_input_name):
|
||||
assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
|
||||
for j in range(len(node.input)):
|
||||
if node.input[j] == old_input_name:
|
||||
node.input[j] = new_input_name
|
||||
|
||||
def replace_input_of_all_nodes(self, old_input_name, new_input_name):
|
||||
for node in self.model.graph.node:
|
||||
OnnxModel.replace_node_input(node, old_input_name, new_input_name)
|
||||
|
||||
def get_initializer(self,name):
|
||||
for tensor in self.model.graph.initializer:
|
||||
if tensor.name == name:
|
||||
return tensor
|
||||
return None
|
||||
|
||||
def get_nodes_by_op_type(self, op_type):
|
||||
return [n for n in self.model.graph.node if n.op_type == op_type]
|
||||
|
||||
def get_children(self, node, input_name_to_nodes=None):
|
||||
if (input_name_to_nodes is None):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
|
||||
children = []
|
||||
for output in node.output:
|
||||
if output in input_name_to_nodes:
|
||||
for node in input_name_to_nodes[output]:
|
||||
children.append(node)
|
||||
return children
|
||||
|
||||
def get_parents(self, node, output_name_to_node=None):
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
parents = []
|
||||
for input in node.input:
|
||||
if input in output_name_to_node:
|
||||
parents.append(output_name_to_node[input])
|
||||
return parents
|
||||
|
||||
def get_parent(self, node, i, output_name_to_node=None):
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
if len(node.input) <= i:
|
||||
return None
|
||||
|
||||
input = node.input[i]
|
||||
if input not in output_name_to_node:
|
||||
return None
|
||||
|
||||
return output_name_to_node[input]
|
||||
|
||||
def match_parent_path(self, node, parent_op_types, parent_input_index=None, output_name_to_node=None):
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
if parent_input_index is None:
|
||||
parent_input_index = [0] * len(parent_op_types)
|
||||
|
||||
assert(len(parent_input_index) == len(parent_op_types))
|
||||
current_node = node
|
||||
matched_parents = []
|
||||
for i, op_type in enumerate(parent_op_types):
|
||||
input_index = parent_input_index[i]
|
||||
if input_index >= len(current_node.input):
|
||||
return None
|
||||
parent = self.get_parent(current_node, input_index, output_name_to_node)
|
||||
if parent is None:
|
||||
return None
|
||||
if parent.op_type == parent_op_types[i]:
|
||||
matched_parents.append(parent)
|
||||
current_node = parent
|
||||
return matched_parents
|
||||
|
||||
def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, recursive=True):
|
||||
children = self.get_children(node, input_name_to_nodes)
|
||||
dq = deque(children)
|
||||
while len(dq) > 0:
|
||||
current_node = dq.pop()
|
||||
if current_node.op_type == child_type:
|
||||
return current_node
|
||||
|
||||
if recursive:
|
||||
children = self.get_children(current_node, input_name_to_nodes)
|
||||
for child in children:
|
||||
dq.appendleft(child)
|
||||
|
||||
return None
|
||||
|
||||
def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True):
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
parents = self.get_parents(node, output_name_to_node)
|
||||
dq = deque(parents)
|
||||
while len(dq) > 0:
|
||||
current_node = dq.pop()
|
||||
if current_node.op_type == parent_type:
|
||||
return current_node
|
||||
|
||||
if recursive:
|
||||
parents = self.get_parents(current_node, output_name_to_node)
|
||||
for parent in parents:
|
||||
dq.appendleft(parent)
|
||||
|
||||
return None
|
||||
|
||||
def get_constant_value(self, output_name):
|
||||
for node in self.get_nodes_by_op_type('Constant'):
|
||||
if node.output[0] == output_name:
|
||||
for att in node.attribute:
|
||||
if att.name == 'value':
|
||||
return numpy_helper.to_array(att.t)
|
||||
|
||||
def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes=None):
|
||||
if input_name_to_nodes is None:
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
|
||||
children = input_name_to_nodes[root_node.output[0]]
|
||||
|
||||
unique_nodes = []
|
||||
|
||||
dq = deque(children)
|
||||
while len(dq) > 0:
|
||||
current_node = dq.pop()
|
||||
if current_node in stop_nodes:
|
||||
continue
|
||||
|
||||
if current_node not in unique_nodes:
|
||||
unique_nodes.append(current_node)
|
||||
|
||||
for output in current_node.output:
|
||||
if output in input_name_to_nodes:
|
||||
children = input_name_to_nodes[output]
|
||||
for child in children:
|
||||
dq.appendleft(child)
|
||||
|
||||
return unique_nodes
|
||||
|
||||
def convert_model_float32_to_float16(self):
|
||||
graph = self.model.graph
|
||||
initializers = graph.initializer
|
||||
|
||||
for input_value_info in graph.input:
|
||||
if input_value_info.type.tensor_type.elem_type == 1:
|
||||
input_value_info.type.tensor_type.elem_type = 10
|
||||
|
||||
for output_value_info in graph.output:
|
||||
if output_value_info.type.tensor_type.elem_type == 1:
|
||||
output_value_info.type.tensor_type.elem_type = 10
|
||||
|
||||
for initializer in initializers:
|
||||
if initializer.data_type == 1:
|
||||
initializer.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(initializer).astype(np.float16), initializer.name))
|
||||
|
||||
for node in graph.node:
|
||||
if node.op_type == 'Constant':
|
||||
for att in node.attribute:
|
||||
if att.name == 'value' and att.t.data_type == 1:
|
||||
att.CopyFrom(onnx.helper.make_attribute("value", numpy_helper.from_array(numpy_helper.to_array(att.t).astype(np.float16))))
|
||||
if node.op_type == 'Cast':
|
||||
for att in node.attribute:
|
||||
if att.name == 'to' and att.i == 1:
|
||||
att.CopyFrom(onnx.helper.make_attribute("to", 10))
|
||||
|
||||
# create a new name for node
|
||||
def create_node_name(self, op_type, name_prefix=None):
|
||||
if op_type in self.node_name_counter:
|
||||
self.node_name_counter[op_type] += 1
|
||||
else:
|
||||
self.node_name_counter[op_type] = 1
|
||||
|
||||
if name_prefix is not None:
|
||||
full_name = name_prefix + str(self.node_name_counter[op_type])
|
||||
else:
|
||||
full_name = op_type + "_" + str(self.node_name_counter[op_type])
|
||||
|
||||
# Check whether the name is taken:
|
||||
nodes = self.get_nodes_by_op_type(op_type)
|
||||
for node in nodes:
|
||||
if node.name == full_name:
|
||||
raise Exception("Node name already taken:", full_name)
|
||||
|
||||
return full_name
|
||||
|
||||
|
||||
def find_graph_input(self, input_name):
|
||||
for input in self.model.graph.input:
|
||||
if input.name == input_name:
|
||||
return input
|
||||
return None
|
||||
|
||||
def get_parent_subgraph_nodes(self, node, stop_nodes, output_name_to_node=None):
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
unique_nodes = []
|
||||
|
||||
parents = self.get_parents(node, output_name_to_node)
|
||||
dq = deque(parents)
|
||||
while len(dq) > 0:
|
||||
current_node = dq.pop()
|
||||
if current_node in stop_nodes:
|
||||
continue
|
||||
|
||||
if current_node not in unique_nodes:
|
||||
unique_nodes.append(current_node)
|
||||
|
||||
for input in current_node.input:
|
||||
if input in output_name_to_node:
|
||||
dq.appendleft(output_name_to_node[input])
|
||||
|
||||
return unique_nodes
|
||||
|
||||
@staticmethod
|
||||
def input_index(node_output, child_node):
|
||||
index = 0
|
||||
for input in child_node.input:
|
||||
if input == node_output:
|
||||
return index
|
||||
index += 1
|
||||
return -1
|
||||
|
||||
def remove_unused_constant(self):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
|
||||
#remove unused constant
|
||||
unused_nodes = []
|
||||
nodes = self.nodes()
|
||||
for node in nodes:
|
||||
if node.op_type == "Constant" and node.output[0] not in input_name_to_nodes:
|
||||
unused_nodes.append(node)
|
||||
|
||||
self.remove_nodes(unused_nodes)
|
||||
|
||||
if len(unused_nodes) > 0:
|
||||
print("Removed unused constant nodes:", len(unused_nodes))
|
||||
|
||||
def update_graph(self, verbose=False):
|
||||
graph = self.model.graph
|
||||
|
||||
remaining_input_names = []
|
||||
for node in graph.node:
|
||||
if node.op_type != "Constant":
|
||||
for input_name in node.input:
|
||||
if input_name not in remaining_input_names:
|
||||
remaining_input_names.append(input_name)
|
||||
if verbose:
|
||||
print("remaining input names", remaining_input_names)
|
||||
|
||||
# remove graph input that is not used
|
||||
inputs_to_remove = []
|
||||
for input in graph.input:
|
||||
if input.name not in remaining_input_names:
|
||||
inputs_to_remove.append(input)
|
||||
for input in inputs_to_remove:
|
||||
graph.input.remove(input)
|
||||
if verbose:
|
||||
print("remove unused input ", len(inputs_to_remove), [input.name for input in inputs_to_remove])
|
||||
|
||||
# remove weights that are not used
|
||||
weights_to_remove = []
|
||||
for initializer in graph.initializer:
|
||||
if initializer.name not in remaining_input_names:
|
||||
weights_to_remove.append(initializer)
|
||||
for initializer in weights_to_remove:
|
||||
graph.initializer.remove(initializer)
|
||||
if verbose:
|
||||
print("remove unused initializers:", len(weights_to_remove), [initializer.name for initializer in weights_to_remove])
|
||||
|
||||
self.remove_unused_constant()
|
||||
|
||||
class BertOnnxModel(OnnxModel):
|
||||
def __init__(self, model, num_heads, hidden_size, sequence_length):
|
||||
assert num_heads > 0
|
||||
assert hidden_size % num_heads == 0
|
||||
assert sequence_length > 0
|
||||
|
||||
super(BertOnnxModel, self).__init__(model)
|
||||
self.num_heads = num_heads
|
||||
self.sequence_length = sequence_length
|
||||
self.hidden_size = hidden_size
|
||||
self.mask_input = None
|
||||
self.embed_node = None
|
||||
|
||||
# constant node names
|
||||
self.normalize_name = "SkipLayerNormalization"
|
||||
self.gelu_name = 'Gelu'
|
||||
self.attention_name = 'Attention'
|
||||
|
||||
def get_normalize_nodes(self):
|
||||
return self.get_nodes_by_op_type(self.normalize_name)
|
||||
|
||||
def normalize_children_types(self):
|
||||
return ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']
|
||||
|
||||
def set_mask_input(self, input):
|
||||
if self.mask_input is not None and input != self.mask_input:
|
||||
raise Exception("Different mask inputs", self.mask_input, input)
|
||||
|
||||
self.mask_input = input
|
||||
|
||||
def create_attention_node(self, q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, input, output):
|
||||
q_weight = self.get_initializer(q_matmul.input[1])
|
||||
k_weight = self.get_initializer(k_matmul.input[1])
|
||||
v_weight = self.get_initializer(v_matmul.input[1])
|
||||
q_bias = self.get_initializer(q_add.input[1])
|
||||
k_bias = self.get_initializer(k_add.input[1])
|
||||
v_bias = self.get_initializer(v_add.input[1])
|
||||
|
||||
qw = numpy_helper.to_array(q_weight)
|
||||
assert qw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
kw = numpy_helper.to_array(k_weight)
|
||||
assert kw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
vw = numpy_helper.to_array(v_weight)
|
||||
assert vw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
qkv_weight = np.stack((qw, kw, vw), axis=-2)
|
||||
|
||||
qb = numpy_helper.to_array(q_bias)
|
||||
assert qb.shape == (self.hidden_size,)
|
||||
|
||||
kb = numpy_helper.to_array(k_bias)
|
||||
assert kb.shape == (self.hidden_size,)
|
||||
|
||||
vb = numpy_helper.to_array(v_bias)
|
||||
assert vb.shape == (self.hidden_size,)
|
||||
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=-2)
|
||||
|
||||
attention_node_name = self.create_node_name(self.attention_name)
|
||||
|
||||
weight = onnx.helper.make_tensor(name=attention_node_name + '_qkv_weight',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[self.hidden_size, 3 * self.hidden_size],
|
||||
vals=qkv_weight.flatten().tolist())
|
||||
self.add_initializer(weight)
|
||||
|
||||
weight_input = onnx.helper.make_tensor_value_info(weight.name, TensorProto.FLOAT, [self.hidden_size, 3 * self.hidden_size])
|
||||
self.add_input(weight_input)
|
||||
|
||||
bias = onnx.helper.make_tensor(name=attention_node_name + '_qkv_bias',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[3 * self.hidden_size],
|
||||
vals=qkv_bias.flatten().tolist())
|
||||
self.add_initializer(bias)
|
||||
|
||||
bias_input = onnx.helper.make_tensor_value_info(bias.name, TensorProto.FLOAT, [3 * self.hidden_size])
|
||||
self.add_input(bias_input)
|
||||
|
||||
attention_node = onnx.helper.make_node(self.attention_name,
|
||||
inputs=[input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias', self.mask_input],
|
||||
outputs=[output],
|
||||
name=attention_node_name)
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([onnx.helper.make_attribute("num_heads", self.num_heads)])
|
||||
|
||||
self.add_node(attention_node)
|
||||
|
||||
def fuse_attention(self, verbose=False):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
|
||||
for normalize_node in self.get_normalize_nodes():
|
||||
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
|
||||
qkv_nodes = None
|
||||
root_input = None
|
||||
for i, input in enumerate(normalize_node.input):
|
||||
if input not in output_name_to_node:
|
||||
continue
|
||||
children = input_name_to_nodes[input]
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
if children_types != self.normalize_children_types():
|
||||
qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], [i, 0, 0, 0, 0])
|
||||
else:
|
||||
root_input = input
|
||||
|
||||
if root_input is None or qkv_nodes is None:
|
||||
continue
|
||||
|
||||
(add_qkv, matmul_qkv, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
|
||||
|
||||
v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if v_nodes is None:
|
||||
continue
|
||||
(transpose_v, reshape_v, add_v, matmul_v) = v_nodes
|
||||
|
||||
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Div', 'MatMul'], [0, 0, 0, 0])
|
||||
if qk_nodes is None:
|
||||
continue
|
||||
(softmax_qk, add_qk, div_qk, matmul_qk) = qk_nodes
|
||||
|
||||
q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
|
||||
if q_nodes is None:
|
||||
continue
|
||||
(transpose_q, reshape_q, add_q, matmul_q) = q_nodes
|
||||
|
||||
k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
|
||||
if k_nodes is None:
|
||||
continue
|
||||
(transpose_k, reshape_k, add_k, matmul_k) = k_nodes
|
||||
|
||||
mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [1, 0, 1, 0, 0])
|
||||
if mask_nodes is None:
|
||||
continue
|
||||
(mul_mask, sub_mask, cast_mask, unsqueeze_mask, unsqueeze_mask_0) = mask_nodes
|
||||
|
||||
if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input:
|
||||
self.set_mask_input(unsqueeze_mask_0.input[0])
|
||||
self.create_attention_node(matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, root_input, reshape_qkv.output[0])
|
||||
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
|
||||
nodes_to_remove.extend(qk_nodes)
|
||||
nodes_to_remove.extend(q_nodes)
|
||||
nodes_to_remove.extend(k_nodes)
|
||||
nodes_to_remove.extend(v_nodes)
|
||||
nodes_to_remove.extend(mask_nodes)
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.update_graph(verbose)
|
||||
|
||||
def fuse_gelu(self):
|
||||
nodes = self.nodes()
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
|
||||
for node in self.get_normalize_nodes():
|
||||
|
||||
children = input_name_to_nodes[node.output[0]]
|
||||
if len(children) != 2:
|
||||
continue
|
||||
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
if children_types != ['MatMul', 'SkipLayerNormalization']:
|
||||
continue
|
||||
|
||||
matmul_node = self.find_first_child_by_type(node, 'MatMul', input_name_to_nodes)
|
||||
matmul_child = input_name_to_nodes[matmul_node.output[0]]
|
||||
if len(matmul_child) != 1 or matmul_child[0].op_type != 'Add':
|
||||
continue
|
||||
add_node = matmul_child[0]
|
||||
|
||||
children = input_name_to_nodes[add_node.output[0]]
|
||||
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
if children_types != ['Div', 'Mul']:
|
||||
continue
|
||||
|
||||
matmul_2 = self.find_first_child_by_type(add_node, 'MatMul', input_name_to_nodes)
|
||||
if matmul_2 is None:
|
||||
continue
|
||||
|
||||
subgraph_nodes = self.get_children_subgraph_nodes(add_node, [matmul_2], input_name_to_nodes)
|
||||
if len(subgraph_nodes) != 5:
|
||||
continue
|
||||
|
||||
nodes_to_remove.extend(subgraph_nodes)
|
||||
gelu_node = onnx.helper.make_node(self.gelu_name,
|
||||
inputs=[add_node.output[0]],
|
||||
outputs=[matmul_2.input[0]])
|
||||
gelu_node.domain = "com.microsoft"
|
||||
nodes_to_add.append(gelu_node)
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
||||
def fuse_reshape(self):
|
||||
nodes = self.nodes()
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
|
||||
for reshape_node in self.get_nodes_by_op_type('Reshape'):
|
||||
concat_node = output_name_to_node[reshape_node.input[1]]
|
||||
if concat_node.op_type != 'Concat' or len(concat_node.input) < 3:
|
||||
continue
|
||||
|
||||
path = self.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0], output_name_to_node)
|
||||
if path is None:
|
||||
continue
|
||||
(unsqueeze_0, gather_0, shape_0) = path
|
||||
|
||||
path = self.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0], output_name_to_node)
|
||||
if path is None:
|
||||
continue
|
||||
(unsqueeze_1, gather_1, shape_1) = path
|
||||
|
||||
shape = []
|
||||
gather_value = self.get_constant_value(gather_0.input[1])
|
||||
if gather_value == 0:
|
||||
shape.append(0)
|
||||
|
||||
gather_value = self.get_constant_value(gather_1.input[1])
|
||||
if gather_value == 1:
|
||||
shape.append(0)
|
||||
|
||||
if len(shape) != 2:
|
||||
continue
|
||||
|
||||
if (len(concat_node.input) > 2):
|
||||
concat_2 = self.get_initializer(concat_node.input[2])
|
||||
if concat_2 is None:
|
||||
continue
|
||||
shape.extend(numpy_helper.to_array(concat_2))
|
||||
|
||||
if (len(concat_node.input) > 3):
|
||||
concat_3 = self.get_initializer(concat_node.input[3])
|
||||
if concat_3 is None:
|
||||
continue
|
||||
shape.extend(numpy_helper.to_array(concat_3))
|
||||
shape_value = np.asarray(shape, dtype=np.int64)
|
||||
|
||||
constant_shape_name = self.create_node_name('Constant', 'constant_shape')
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
inputs=[],
|
||||
outputs=[constant_shape_name],
|
||||
value=onnx.helper.make_tensor(
|
||||
name='const_tensor',
|
||||
data_type=TensorProto.INT64,
|
||||
dims=shape_value.shape,
|
||||
vals=shape_value))
|
||||
reshape_node.input[1] = constant_shape_name
|
||||
nodes_to_remove.extend([concat_node, unsqueeze_0, unsqueeze_1, gather_0, gather_1, shape_0, shape_1])
|
||||
nodes_to_add.append(new_node)
|
||||
|
||||
print("Fused reshape count:", len(nodes_to_add))
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
|
||||
"""
|
||||
Embed Layer Normalization will fuse embeddings and mask processing into one node.
|
||||
The embeddings before conversion:
|
||||
|
||||
(input_ids) --------> Gather ----------+ (segment_ids)
|
||||
| | |
|
||||
| v v
|
||||
+--> Shape --> Expand -> Gather---->Add Gather
|
||||
| ^ | |
|
||||
| | v v
|
||||
+---(optional graph) SkipLayerNormalization
|
||||
|
||||
Optional graph is used to generate position list (0, 1, ...). It can be a constant in some model.
|
||||
"""
|
||||
def fuse_embed_layer(self, verbose=False):
|
||||
if self.mask_input is None:
|
||||
print("skip embed layer fusion since mask input is not found")
|
||||
return
|
||||
|
||||
nodes = self.nodes()
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
mask_input_name = self.mask_input
|
||||
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
|
||||
# Find the first normalize node could be embedding layer.
|
||||
normalize_node = None
|
||||
for node in self.get_normalize_nodes():
|
||||
if self.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is not None:
|
||||
if self.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is not None:
|
||||
normalize_node = node
|
||||
break
|
||||
|
||||
if normalize_node is None:
|
||||
print("did not find embedding layer")
|
||||
|
||||
# Here we assume the order of embedding is word_embedding + position_embedding + segment_embedding.
|
||||
word_embedding_path = self.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
|
||||
if word_embedding_path is None:
|
||||
print("Failed to find word embedding")
|
||||
return
|
||||
add_node, word_embedding_gather = word_embedding_path
|
||||
|
||||
position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
|
||||
if position_embedding_path is None:
|
||||
print("Failed to find position embedding")
|
||||
return
|
||||
position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path
|
||||
|
||||
segment_embedding_path = self.match_parent_path(normalize_node, ['Gather'], [1])
|
||||
if segment_embedding_path is None:
|
||||
print("failed to find segment embedding")
|
||||
return
|
||||
segment_embedding_gather = segment_embedding_path[0]
|
||||
|
||||
input_ids = word_embedding_gather.input[1]
|
||||
segment_ids = segment_embedding_gather.input[1]
|
||||
|
||||
if position_embedding_shape.input[0] != input_ids:
|
||||
print("position and word embedding is expected to be applied on same input")
|
||||
return
|
||||
|
||||
subgraph_nodes = self.get_parent_subgraph_nodes(position_embedding_expand, [input_ids], output_name_to_node)
|
||||
|
||||
nodes_to_remove.extend(subgraph_nodes)
|
||||
nodes_to_remove.extend([normalize_node, add_node, segment_embedding_gather, word_embedding_gather, position_embedding_gather, position_embedding_expand])
|
||||
|
||||
embed_node = onnx.helper.make_node('EmbedLayerNormalization',
|
||||
inputs=[input_ids, segment_ids, mask_input_name,
|
||||
word_embedding_gather.input[0], position_embedding_gather.input[0], segment_embedding_gather.input[0],
|
||||
normalize_node.input[2], normalize_node.input[3]], # gamma and beta
|
||||
outputs=["embed_output", "mask_idx"],
|
||||
name="EmbedLayer")
|
||||
embed_node.domain = "com.microsoft"
|
||||
# store embed node for other processing
|
||||
self.embed_node = embed_node
|
||||
|
||||
nodes_to_add.extend([embed_node])
|
||||
|
||||
self.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')
|
||||
self.replace_input_of_all_nodes(mask_input_name, 'mask_idx')
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
self.update_graph(verbose)
|
||||
|
||||
def get_batch_size_from_graph_input(self):
|
||||
graph = self.graph()
|
||||
for input in graph.input:
|
||||
if input.name in self.embed_node.input[:3]:
|
||||
tensor_type = input.type.tensor_type
|
||||
if (tensor_type.HasField("shape")):
|
||||
for d in tensor_type.shape.dim:
|
||||
if (d.HasField("dim_value")):
|
||||
return d.dim_value
|
||||
elif (d.HasField("dim_param")):
|
||||
return str(d.dim_param) # unknown dimension with symbolic name
|
||||
return None
|
||||
return None
|
||||
|
||||
def change_input_to_int32(self):
|
||||
original_opset_version = self.model.opset_import[0].version
|
||||
graph = self.graph()
|
||||
|
||||
batch_size = self.get_batch_size_from_graph_input()
|
||||
input_batch_size = batch_size if isinstance(batch_size, int) else 1
|
||||
new_graph_inputs = []
|
||||
for input in graph.input:
|
||||
if input.name in self.embed_node.input[:3]: # Only the first 3 inputs of embed node need int32 conversion.
|
||||
int32_input = onnx.helper.make_tensor_value_info(input.name, TensorProto.INT32, [input_batch_size, self.sequence_length])
|
||||
new_graph_inputs.append(int32_input)
|
||||
else:
|
||||
new_graph_inputs.append(input)
|
||||
|
||||
graph_def = onnx.helper.make_graph(graph.node,
|
||||
'int32 inputs',
|
||||
new_graph_inputs,
|
||||
graph.output,
|
||||
initializer=graph.initializer,
|
||||
value_info=graph.value_info)
|
||||
|
||||
self.model = onnx.helper.make_model(graph_def, producer_name='bert model optimizer')
|
||||
|
||||
if isinstance(batch_size, str):
|
||||
self.update_dynamic_batch_io(batch_size)
|
||||
|
||||
# restore opset version
|
||||
self.model.opset_import[0].version = original_opset_version
|
||||
|
||||
def cast_input_to_int32(self):
|
||||
for input in self.embed_node.input[:3]:
|
||||
graph_input = self.find_graph_input(input)
|
||||
if graph_input is not None and graph_input.type.tensor_type.elem_type == TensorProto.INT64:
|
||||
cast_output = input + '_int32'
|
||||
cast_node = onnx.helper.make_node('Cast', inputs=[input], outputs=[cast_output])
|
||||
cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))])
|
||||
self.replace_input_of_all_nodes(input, cast_output)
|
||||
self.add_node(cast_node)
|
||||
|
||||
# Update input and output using dynamic batch
|
||||
def update_dynamic_batch_io(self, dynamic_batch_dim='batch'):
|
||||
dynamic_batch_inputs = {}
|
||||
for input in self.model.graph.input:
|
||||
for embed_input in self.embed_node.input[:3]:
|
||||
if embed_input == input.name:
|
||||
dim_proto = input.type.tensor_type.shape.dim[0]
|
||||
dim_proto.dim_param = dynamic_batch_dim
|
||||
|
||||
for output in self.model.graph.output:
|
||||
dim_proto = output.type.tensor_type.shape.dim[0]
|
||||
dim_proto.dim_param = dynamic_batch_dim
|
||||
|
||||
"""
|
||||
Layer Normalization will fuse Add + LayerNormalization into one node:
|
||||
+----------------------+
|
||||
| |
|
||||
| v
|
||||
Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
|
||||
| ^
|
||||
| |
|
||||
+-----------------------------------------------+
|
||||
|
||||
It also handles cases of duplicated sub nodes exported from older version of PyTorch:
|
||||
+----------------------+
|
||||
| v
|
||||
| +-------> Sub-----------------------------------------------+
|
||||
| | |
|
||||
| | v
|
||||
Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
|
||||
| ^
|
||||
| |
|
||||
+----------------------+
|
||||
"""
|
||||
def fuse_layer_norm(self):
|
||||
input_name_to_nodes = self.input_name_to_nodes()
|
||||
output_name_to_node = self.output_name_to_node()
|
||||
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
|
||||
for node in self.nodes():
|
||||
if node.op_type == 'Add':
|
||||
children = self.get_children(node, input_name_to_nodes)
|
||||
children_types = sorted([child.op_type for child in children])
|
||||
if children_types != ["ReduceMean", "Sub"] and children_types != ["ReduceMean", "Sub", "Sub"]:
|
||||
continue
|
||||
|
||||
div_node = None
|
||||
for child in children:
|
||||
if child.op_type == 'Sub':
|
||||
div_node = self.find_first_child_by_type(child, 'Div', input_name_to_nodes, recursive=False)
|
||||
if div_node is not None:
|
||||
break
|
||||
if div_node is None:
|
||||
continue
|
||||
|
||||
parent_nodes = self.match_parent_path(div_node, ['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Sub', 'Add'], [1, 0, 0, 0, 0, 0], output_name_to_node)
|
||||
if parent_nodes is None:
|
||||
continue
|
||||
|
||||
sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node, first_add_node = parent_nodes
|
||||
if first_add_node != node:
|
||||
continue
|
||||
|
||||
mul_node = input_name_to_nodes[div_node.output[0]][0]
|
||||
if mul_node.op_type != 'Mul':
|
||||
continue
|
||||
|
||||
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
|
||||
if last_add_node.op_type != 'Add':
|
||||
continue
|
||||
|
||||
nodes_to_remove.append(node)
|
||||
nodes_to_remove.extend(children)
|
||||
nodes_to_remove.extend([last_add_node, mul_node, div_node, sqrt_node, second_add_node, reduce_mean_node, pow_node])
|
||||
|
||||
normalize_node_name = self.create_node_name(self.normalize_name, name_prefix="SkipLayerNorm")
|
||||
inputs = [i for i in node.input]
|
||||
inputs.extend([mul_node.input[0], last_add_node.input[1]])
|
||||
normalize_node = onnx.helper.make_node(self.normalize_name,
|
||||
inputs=inputs,
|
||||
outputs=[last_add_node.output[0]],
|
||||
name=normalize_node_name)
|
||||
normalize_node.domain = "com.microsoft"
|
||||
nodes_to_add.extend([normalize_node])
|
||||
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
self.add_nodes(nodes_to_add)
|
||||
print("Fused layer normalization count:", len(nodes_to_add))
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True, type=str)
|
||||
parser.add_argument('--output', required=True, type=str)
|
||||
|
||||
# model parameters
|
||||
parser.add_argument('--num_heads', required=False, type=int, default=12, help="number of attention heads")
|
||||
parser.add_argument('--hidden_size', required=False, type=int, default=768)
|
||||
parser.add_argument('--sequence_length', required=False, type=int, default=128)
|
||||
|
||||
# Use int32 (instead of int64) tensor as input to avoid unnecessary data type cast.
|
||||
parser.add_argument('--input_int32', required=False, action='store_true')
|
||||
parser.set_defaults(input_int32=False)
|
||||
|
||||
# For NVidia GPU with Tensor Core like V100 and T4, half-precision float brings better performance.
|
||||
parser.add_argument('--float16', required=False, action='store_true')
|
||||
parser.set_defaults(float16=False)
|
||||
|
||||
parser.add_argument('--verbose', required=False, action='store_true')
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model = ModelProto()
|
||||
with open(args.input, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
|
||||
bert_model = BertOnnxModel(model, args.num_heads, args.hidden_size, args.sequence_length)
|
||||
|
||||
bert_model.fuse_layer_norm()
|
||||
|
||||
bert_model.fuse_gelu()
|
||||
|
||||
bert_model.fuse_reshape()
|
||||
|
||||
bert_model.fuse_attention(args.verbose)
|
||||
|
||||
bert_model.fuse_embed_layer(args.verbose)
|
||||
|
||||
if bert_model.embed_node is None:
|
||||
print("Failed to fuse embedding layer.")
|
||||
return
|
||||
|
||||
if args.input_int32:
|
||||
bert_model.change_input_to_int32()
|
||||
else:
|
||||
bert_model.cast_input_to_int32()
|
||||
|
||||
|
||||
if args.float16:
|
||||
bert_model.convert_model_float32_to_float16()
|
||||
|
||||
with open(args.output, "wb") as out:
|
||||
out.write(bert_model.model.SerializeToString())
|
||||
|
||||
main()
|
||||
Loading…
Reference in a new issue