Update transformers tool based on latest transformers (#6641)

* bert_base_cased: embedlayer fusion

* xlm_mlm_en_2048: attention fusion
This commit is contained in:
Ye Wang 2021-02-11 10:11:47 -08:00 committed by GitHub
parent a7b6fc08f2
commit b4b829dfcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 25 deletions

View file

@ -212,6 +212,8 @@ class FusionAttention(Fusion):
root_input = layernorm_node.output[0]
else:
return
elif mul_children is not None and len(mul_children) == 5:
root_input = mul_before_layernorm.output[0]
else:
return

View file

@ -37,7 +37,7 @@ class FusionEmbedLayerNoMask(Fusion):
SkipLayerNormalization
"""
def __init__(self, model: OnnxModel, description='no mask'):
super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
super().__init__(model, "EmbedLayerNormalization", ["SkipLayerNormalization", "LayerNormalization"], description)
self.utils = FusionUtils(model)
self.attention = None
@ -45,16 +45,24 @@ class FusionEmbedLayerNoMask(Fusion):
segment_ids = None
segment_embedding_gather = None
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
if normalize_node.op_type == "SkipLayerNormalization":
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
if segment_embedding_path is None:
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
if segment_embedding_path is None:
logger.info("Segment embedding is not found. Embed layer cannot be fused.")
return
_, segment_embedding_gather = segment_embedding_path
else:
segment_embedding_gather = segment_embedding_path[0]
elif normalize_node.op_type == "LayerNormalization":
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 1])
if segment_embedding_path is None:
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
if segment_embedding_path is None:
logger.info("Segment embedding is not found. Embed layer cannot be fused.")
return
_, segment_embedding_gather = segment_embedding_path
else:
segment_embedding_gather = segment_embedding_path[0]
_, _, segment_embedding_gather = segment_embedding_path
segment_ids = segment_embedding_gather.input[1]
@ -92,7 +100,8 @@ class FusionEmbedLayerNoMask(Fusion):
logger.debug(
"Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather"
)
return
if node.op_type != "LayerNormalization" or self.model.match_parent_path(node, ['Add', 'Gather'], [0, 1]) is None:
return
self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
if self.attention is None:
@ -114,19 +123,23 @@ class FusionEmbedLayerNoMask(Fusion):
if word_embedding_path is not None:
add_node, word_embedding_gather = word_embedding_path
else:
word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 0])
if word_embedding_path is not None:
word_embedding_gather = word_embedding_path[0]
is_distill = True
from packaging.version import Version
import onnxruntime
if Version(onnxruntime.__version__) <= Version("1.4.0"):
logger.warning(
'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
return
_, add_node, word_embedding_gather = word_embedding_path
else:
logger.info("Word embedding path is not found. Embed layer cannot be fused.")
return
word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
if word_embedding_path is not None:
word_embedding_gather = word_embedding_path[0]
is_distill = True
from packaging.version import Version
import onnxruntime
if Version(onnxruntime.__version__) <= Version("1.4.0"):
logger.warning(
'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
return
else:
logger.info("Word embedding path is not found. Embed layer cannot be fused.")
return
input_ids = word_embedding_gather.input[1]
@ -162,8 +175,12 @@ class FusionEmbedLayerNoMask(Fusion):
if position_embedding_path is not None:
position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
else:
logger.info("Position embedding path is not found. Embed layer cannot be fused.")
return
position_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather', 'Slice'], [0, 1, 1])
if position_embedding_path is not None:
_, position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
else:
logger.info("Position embedding path is not found. Embed layer cannot be fused.")
return
if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
logger.info("position and word embedding is expected to be applied on same input")
@ -191,6 +208,13 @@ class FusionEmbedLayerNoMask(Fusion):
node_name = self.model.create_node_name('EmbedLayerNormalization')
output_name = node_name + "_output"
if normalize_node.op_type == "LayerNormalization":
gamma = normalize_node.input[1]
beta = normalize_node.input[2]
elif normalize_node.op_type == "SkipLayerNormalization":
gamma = normalize_node.input[2]
beta = normalize_node.input[3]
embed_node_inputs = None
if is_distill == False:
segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
@ -206,8 +230,8 @@ class FusionEmbedLayerNoMask(Fusion):
word_embedding_gather.input[0],
position_embedding_weight_node.input[0],
segment_embedding_gather.input[0],
normalize_node.input[2],
normalize_node.input[3] # gamma and beta
gamma,
beta # gamma and beta
]
else:
embed_node_inputs = [
@ -216,8 +240,8 @@ class FusionEmbedLayerNoMask(Fusion):
word_embedding_gather.input[0],
position_embedding_weight_node.input[0],
'',
normalize_node.input[2],
normalize_node.input[3] # gamma and beta
gamma,
beta # gamma and beta
]
embed_node = helper.make_node('EmbedLayerNormalization',