diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index b413c3dcd9..da984c44ad 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -212,6 +212,8 @@ class FusionAttention(Fusion): root_input = layernorm_node.output[0] else: return + elif mul_children is not None and len(mul_children) == 5: + root_input = mul_before_layernorm.output[0] else: return diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index c5812caab6..cd9289d308 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -37,7 +37,7 @@ class FusionEmbedLayerNoMask(Fusion): SkipLayerNormalization """ def __init__(self, model: OnnxModel, description='no mask'): - super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description) + super().__init__(model, "EmbedLayerNormalization", ["SkipLayerNormalization", "LayerNormalization"], description) self.utils = FusionUtils(model) self.attention = None @@ -45,16 +45,24 @@ class FusionEmbedLayerNoMask(Fusion): segment_ids = None segment_embedding_gather = None - segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1]) + if normalize_node.op_type == "SkipLayerNormalization": + segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1]) + + if segment_embedding_path is None: + segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1]) + if segment_embedding_path is None: + logger.info("Segment embedding is not found. Embed layer cannot be fused.") + return + _, segment_embedding_gather = segment_embedding_path + else: + segment_embedding_gather = segment_embedding_path[0] + elif normalize_node.op_type == "LayerNormalization": + segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 1]) - if segment_embedding_path is None: - segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1]) if segment_embedding_path is None: logger.info("Segment embedding is not found. Embed layer cannot be fused.") return - _, segment_embedding_gather = segment_embedding_path - else: - segment_embedding_gather = segment_embedding_path[0] + _, _, segment_embedding_gather = segment_embedding_path segment_ids = segment_embedding_gather.input[1] @@ -92,7 +100,8 @@ class FusionEmbedLayerNoMask(Fusion): logger.debug( "Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather" ) - return + if node.op_type != "LayerNormalization" or self.model.match_parent_path(node, ['Add', 'Gather'], [0, 1]) is None: + return self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) if self.attention is None: @@ -114,19 +123,23 @@ class FusionEmbedLayerNoMask(Fusion): if word_embedding_path is not None: add_node, word_embedding_gather = word_embedding_path else: - word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0]) + word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 0]) if word_embedding_path is not None: - word_embedding_gather = word_embedding_path[0] - is_distill = True - from packaging.version import Version - import onnxruntime - if Version(onnxruntime.__version__) <= Version("1.4.0"): - logger.warning( - 'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert') - return + _, add_node, word_embedding_gather = word_embedding_path else: - logger.info("Word embedding path is not found. Embed layer cannot be fused.") - return + word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0]) + if word_embedding_path is not None: + word_embedding_gather = word_embedding_path[0] + is_distill = True + from packaging.version import Version + import onnxruntime + if Version(onnxruntime.__version__) <= Version("1.4.0"): + logger.warning( + 'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert') + return + else: + logger.info("Word embedding path is not found. Embed layer cannot be fused.") + return input_ids = word_embedding_gather.input[1] @@ -162,8 +175,12 @@ class FusionEmbedLayerNoMask(Fusion): if position_embedding_path is not None: position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path else: - logger.info("Position embedding path is not found. Embed layer cannot be fused.") - return + position_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather', 'Slice'], [0, 1, 1]) + if position_embedding_path is not None: + _, position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path + else: + logger.info("Position embedding path is not found. Embed layer cannot be fused.") + return if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids: logger.info("position and word embedding is expected to be applied on same input") @@ -191,6 +208,13 @@ class FusionEmbedLayerNoMask(Fusion): node_name = self.model.create_node_name('EmbedLayerNormalization') output_name = node_name + "_output" + if normalize_node.op_type == "LayerNormalization": + gamma = normalize_node.input[1] + beta = normalize_node.input[2] + elif normalize_node.op_type == "SkipLayerNormalization": + gamma = normalize_node.input[2] + beta = normalize_node.input[3] + embed_node_inputs = None if is_distill == False: segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node, @@ -206,8 +230,8 @@ class FusionEmbedLayerNoMask(Fusion): word_embedding_gather.input[0], position_embedding_weight_node.input[0], segment_embedding_gather.input[0], - normalize_node.input[2], - normalize_node.input[3] # gamma and beta + gamma, + beta # gamma and beta ] else: embed_node_inputs = [ @@ -216,8 +240,8 @@ class FusionEmbedLayerNoMask(Fusion): word_embedding_gather.input[0], position_embedding_weight_node.input[0], '', - normalize_node.input[2], - normalize_node.input[3] # gamma and beta + gamma, + beta # gamma and beta ] embed_node = helper.make_node('EmbedLayerNormalization',