diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index a20febb9f0..bc38399e3c 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -378,7 +378,7 @@ class FusionEmbedLayerNoMask(Fusion): logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected") return False - # In normal case, word embeding table is the largest, and segment embedding table is the smallest, while postion embedding table is in between. + # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between. # TODO: use other information (like initializer names) to identify different embedding weights automatically. if word_embedding_table.shape[0] <= position_embedding_table.shape[0]: logger.warning( @@ -430,6 +430,7 @@ class FusionEmbedLayerNoMask(Fusion): segment_embedding_gather: Union[None, NodeProto], position_ids: Optional[str] = None, embedding_sum_output=False, + embedding_sum_name=None, ): """Create an EmbedLayerNormalization node. Note that segment embedding is optional. @@ -487,7 +488,8 @@ class FusionEmbedLayerNoMask(Fusion): embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"] if embedding_sum_output: - embed_node_outputs.append(node_name + "_embedding_sum") + name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum" + embed_node_outputs.append(name) embed_node = helper.make_node( "EmbedLayerNormalization", @@ -522,19 +524,8 @@ class FusionEmbedLayerNoMask(Fusion): # use prune graph to remove nodes that is not needed self.prune_graph = True - def is_embedding_sum_needed(self, add_before_layer_norm): - """Check that Add before layer norm has an output to add before next layernorm - - Args: - add_before_layer_norm (NodeProto): Add before any LayerNormalization node in topological order of graph - - Returns: - bool: whether there is an extra output needed out of embed layer norm node - """ - - nodes = self.model.get_children(add_before_layer_norm) - - return len(nodes) > 1 + def is_skip_layer_norm_with_sum_output(self, node): + return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0 def fuse_gpt2( self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None @@ -570,21 +561,31 @@ class FusionEmbedLayerNoMask(Fusion): if not self.check_embedding(word_embedding_gather, None, position_embedding_gather): return False - # If the add_before_layernorm node is an Add node, then the add_output output is the first index - # output of this node. - - # If the add_before_layernorm node is SkipLayerNormalization node, then the add_output output + # If layernorm node is SkipLayerNormalization, we need look at its optional fourth output. + # If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node. + # If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output # is the (optional) fourth index output of this node. - add_output = None - optional_embedding_sum_output = False - if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or ( - add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4 - ): - optional_embedding_sum_output = True - add_output = ( - add_before_layernorm.output[0] - if add_before_layernorm.op_type == "Add" - else add_before_layernorm.output[3] + # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node. + if layernorm.op_type == "SkipLayerNormalization": + need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm) + sum_output_index = 3 + node_with_sum_output = layernorm + sum_output = layernorm.output[3] if need_embedding_sum_output else None + is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None) + else: # layernorm.op_type == "LayerNormalization" + node_with_sum_output = add_before_layernorm + sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3 + sum_output = ( + add_before_layernorm.output[sum_output_index] + if len(add_before_layernorm.output) > sum_output_index + else None + ) + is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None) + is_sum_used_by_multiple_nodes = ( + sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1 + ) + need_embedding_sum_output = (sum_output is not None) and ( + add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes ) # make the fused node @@ -595,14 +596,16 @@ class FusionEmbedLayerNoMask(Fusion): position_embedding_gather, optional_segment_gather, position_ids, - optional_embedding_sum_output, + embedding_sum_output=need_embedding_sum_output, + embedding_sum_name=sum_output if is_sum_graph_output else None, ) - # direct the output to another add too - self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0]) - if optional_embedding_sum_output: - self.model.replace_input_of_all_nodes(add_output, embed_node.output[2]) + if need_embedding_sum_output: + node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_" + if not is_sum_graph_output: + self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2]) + self.finish_fusion(layernorm, embed_node) return True def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node): @@ -707,9 +710,14 @@ class FusionEmbedLayerNoMask(Fusion): gather_0_path = self.model.match_parent_path(node, ["Gather"], [0]) gather_1_path = self.model.match_parent_path(node, ["Gather"], [1]) if gather_0_path is None and gather_1_path is not None: + if first_add_path is None: + return add_before_layernorm = first_add_path[0] optional_segment_gather = gather_1_path[0] elif gather_0_path is not None and gather_1_path is None: + first_add_path = self.model.match_parent_path(node, ["Add"], [1]) + if first_add_path is None: + return add_before_layernorm = first_add_path[0] optional_segment_gather = gather_0_path[0] else: diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py index 6d4d6ea920..4a1b48d4d1 100644 --- a/onnxruntime/test/python/transformers/gpt2_model_generator.py +++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py @@ -555,6 +555,8 @@ def create_gpt2_embedlayer( num_heads=4, epsilon=0.1, one_attention_node=False, + has_skip_layer_norm=True, + output_embedding_sum=False, ): # Construct input and output nodes inputs = [ @@ -564,21 +566,47 @@ def create_gpt2_embedlayer( helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]) ] + if output_embedding_sum: + outputs.append( + helper.make_tensor_value_info( + "embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size] + ) + ) + # Construct graph nodes embed_layernorm_nodes = [ helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"), helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"), helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"), helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"), - helper.make_node( - "SkipLayerNormalization", - ["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"], - ["skip_layernorm_out"], - "skip_layernorm", - domain="com.microsoft", - epsilon=epsilon, - ), ] + + if has_skip_layer_norm: + embed_layernorm_nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"], + ["skip_layernorm_out"] if not output_embedding_sum else ["skip_layernorm_out", "", "", "embedding_sum"], + "skip_layernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + else: + embed_layernorm_nodes.append( + helper.make_node("Add", ["add_0_out", "gather_2_out"], ["embedding_sum"], "embedding_sum") + ) + + embed_layernorm_nodes.append( + helper.make_node( + "LayerNormalization", + ["embedding_sum", "layernorm_weight", "layernorm_bias"], + ["skip_layernorm_out"], + "layernorm", + epsilon=epsilon, + ) + ) + attention_nodes = ( [ helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"), @@ -708,6 +736,7 @@ def create_gpt2_fused_embedlayer( num_heads=4, epsilon=0.1, one_attention_node=False, + output_embedding_sum=False, ): # Construct input and output nodes inputs = [ @@ -716,6 +745,12 @@ def create_gpt2_fused_embedlayer( outputs = [ helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]) ] + if output_embedding_sum: + outputs.append( + helper.make_tensor_value_info( + "embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size] + ) + ) # Construct graph nodes embed_layernorm_nodes = [ @@ -732,7 +767,9 @@ def create_gpt2_fused_embedlayer( "", "ids", ], - ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"], + ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"] + if output_embedding_sum + else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"], "EmbedLayerNormalization_0", domain="com.microsoft", epsilon=epsilon, @@ -876,3 +913,9 @@ if __name__ == "__main__": model = create_gpt2_fused_embedlayer(one_attention_node=True) onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_exp.onnx") + + model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True) + onnx.save(model, "gpt2_embedlayer_one_attn_output_sum.onnx") + + model = create_gpt2_fused_embedlayer(one_attention_node=True, output_embedding_sum=True) + onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx") diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx new file mode 100644 index 0000000000..853f3f5cf7 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx differ diff --git a/onnxruntime/test/python/transformers/test_embedlayer_fusion.py b/onnxruntime/test/python/transformers/test_embedlayer_fusion.py index 732833e5da..ccd367fdbb 100644 --- a/onnxruntime/test/python/transformers/test_embedlayer_fusion.py +++ b/onnxruntime/test/python/transformers/test_embedlayer_fusion.py @@ -74,6 +74,38 @@ class TestFusion(unittest.TestCase): os.remove(original_model_path) os.remove(optimized_model_path) + def test_embedlayer_fusion_with_embedding_sum_output(self): + model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True) + path = "." + original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum.onnx") + optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_opt.onnx") + expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx" + + onnx.save(model, original_model_path) + optimized_model = optimize_model(original_model_path, model_type="gpt2") + optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True) + + self.verify_fusion(optimized_model, expected_model_filename) + self.verify_parity(optimized_model_path, expected_model_filename) + os.remove(original_model_path) + os.remove(optimized_model_path) + + def test_embedlayer_fusion_with_embedding_sum_output_no_sln(self): + model = create_gpt2_embedlayer(one_attention_node=True, has_skip_layer_norm=False, output_embedding_sum=True) + path = "." + original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln.onnx") + optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln_opt.onnx") + expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx" + + onnx.save(model, original_model_path) + optimized_model = optimize_model(original_model_path, model_type="gpt2") + optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True) + + self.verify_fusion(optimized_model, expected_model_filename) + self.verify_parity(optimized_model_path, expected_model_filename) + os.remove(original_model_path) + os.remove(optimized_model_path) + if __name__ == "__main__": unittest.main()