Update onnx_model_bert_tf.py (#8457)

Fix a bug: when layernorm and skiplayernorm are not fused, the program will crash
This commit is contained in:
Ye Wang 2021-07-22 13:50:55 -07:00 committed by GitHub
parent 9a6fa057c8
commit e8ee31bcc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -288,9 +288,8 @@ class BertOnnxModelTF(BertOnnxModel):
start_nodes.extend(skip_layer_norm_nodes)
start_nodes.extend(layer_norm_nodes)
graph_name = self.get_graph_by_node(start_nodes[0]).name
for normalize_node in start_nodes:
graph_name = self.get_graph_by_node(normalize_node).name
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
if normalize_node.op_type == 'LayerNormalization':
add_before_layernorm = self.match_parent(normalize_node, 'Add', 0)