mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Update transformers tool based on latest transformers (#6641)
* bert_base_cased: embedlayer fusion
* xlm_mlm_en_2048: attention fusion
This commit is contained in:
parent
a7b6fc08f2
commit
b4b829dfcf
2 changed files with 51 additions and 25 deletions
|
|
@ -212,6 +212,8 @@ class FusionAttention(Fusion):
|
|||
root_input = layernorm_node.output[0]
|
||||
else:
|
||||
return
|
||||
elif mul_children is not None and len(mul_children) == 5:
|
||||
root_input = mul_before_layernorm.output[0]
|
||||
else:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
SkipLayerNormalization
|
||||
"""
|
||||
def __init__(self, model: OnnxModel, description='no mask'):
|
||||
super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
|
||||
super().__init__(model, "EmbedLayerNormalization", ["SkipLayerNormalization", "LayerNormalization"], description)
|
||||
self.utils = FusionUtils(model)
|
||||
self.attention = None
|
||||
|
||||
|
|
@ -45,16 +45,24 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
segment_ids = None
|
||||
segment_embedding_gather = None
|
||||
|
||||
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
|
||||
if normalize_node.op_type == "SkipLayerNormalization":
|
||||
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
|
||||
|
||||
if segment_embedding_path is None:
|
||||
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
|
||||
if segment_embedding_path is None:
|
||||
logger.info("Segment embedding is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
_, segment_embedding_gather = segment_embedding_path
|
||||
else:
|
||||
segment_embedding_gather = segment_embedding_path[0]
|
||||
elif normalize_node.op_type == "LayerNormalization":
|
||||
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 1])
|
||||
|
||||
if segment_embedding_path is None:
|
||||
segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
|
||||
if segment_embedding_path is None:
|
||||
logger.info("Segment embedding is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
_, segment_embedding_gather = segment_embedding_path
|
||||
else:
|
||||
segment_embedding_gather = segment_embedding_path[0]
|
||||
_, _, segment_embedding_gather = segment_embedding_path
|
||||
|
||||
segment_ids = segment_embedding_gather.input[1]
|
||||
|
||||
|
|
@ -92,7 +100,8 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
logger.debug(
|
||||
"Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather"
|
||||
)
|
||||
return
|
||||
if node.op_type != "LayerNormalization" or self.model.match_parent_path(node, ['Add', 'Gather'], [0, 1]) is None:
|
||||
return
|
||||
|
||||
self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
|
||||
if self.attention is None:
|
||||
|
|
@ -114,19 +123,23 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
if word_embedding_path is not None:
|
||||
add_node, word_embedding_gather = word_embedding_path
|
||||
else:
|
||||
word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
|
||||
word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 0])
|
||||
if word_embedding_path is not None:
|
||||
word_embedding_gather = word_embedding_path[0]
|
||||
is_distill = True
|
||||
from packaging.version import Version
|
||||
import onnxruntime
|
||||
if Version(onnxruntime.__version__) <= Version("1.4.0"):
|
||||
logger.warning(
|
||||
'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
|
||||
return
|
||||
_, add_node, word_embedding_gather = word_embedding_path
|
||||
else:
|
||||
logger.info("Word embedding path is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
|
||||
if word_embedding_path is not None:
|
||||
word_embedding_gather = word_embedding_path[0]
|
||||
is_distill = True
|
||||
from packaging.version import Version
|
||||
import onnxruntime
|
||||
if Version(onnxruntime.__version__) <= Version("1.4.0"):
|
||||
logger.warning(
|
||||
'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
|
||||
return
|
||||
else:
|
||||
logger.info("Word embedding path is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
|
||||
input_ids = word_embedding_gather.input[1]
|
||||
|
||||
|
|
@ -162,8 +175,12 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
if position_embedding_path is not None:
|
||||
position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
|
||||
else:
|
||||
logger.info("Position embedding path is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
position_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather', 'Slice'], [0, 1, 1])
|
||||
if position_embedding_path is not None:
|
||||
_, position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
|
||||
else:
|
||||
logger.info("Position embedding path is not found. Embed layer cannot be fused.")
|
||||
return
|
||||
|
||||
if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
|
||||
logger.info("position and word embedding is expected to be applied on same input")
|
||||
|
|
@ -191,6 +208,13 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
node_name = self.model.create_node_name('EmbedLayerNormalization')
|
||||
output_name = node_name + "_output"
|
||||
|
||||
if normalize_node.op_type == "LayerNormalization":
|
||||
gamma = normalize_node.input[1]
|
||||
beta = normalize_node.input[2]
|
||||
elif normalize_node.op_type == "SkipLayerNormalization":
|
||||
gamma = normalize_node.input[2]
|
||||
beta = normalize_node.input[3]
|
||||
|
||||
embed_node_inputs = None
|
||||
if is_distill == False:
|
||||
segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
|
||||
|
|
@ -206,8 +230,8 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
word_embedding_gather.input[0],
|
||||
position_embedding_weight_node.input[0],
|
||||
segment_embedding_gather.input[0],
|
||||
normalize_node.input[2],
|
||||
normalize_node.input[3] # gamma and beta
|
||||
gamma,
|
||||
beta # gamma and beta
|
||||
]
|
||||
else:
|
||||
embed_node_inputs = [
|
||||
|
|
@ -216,8 +240,8 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
word_embedding_gather.input[0],
|
||||
position_embedding_weight_node.input[0],
|
||||
'',
|
||||
normalize_node.input[2],
|
||||
normalize_node.input[3] # gamma and beta
|
||||
gamma,
|
||||
beta # gamma and beta
|
||||
]
|
||||
|
||||
embed_node = helper.make_node('EmbedLayerNormalization',
|
||||
|
|
|
|||
Loading…
Reference in a new issue