mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
fix embed layer norm fusion with embedding sum output (#17460)
The embedding sum could be graph output (when exporting with output hidden state enabled). Previously, we only check whether there are multiple children node to decide whether to output embedding sum in fused node. This fix will check if the sum is graph output, we will retain the name.
This commit is contained in:
parent
9017ea131b
commit
7bc6dcecf7
4 changed files with 126 additions and 43 deletions
|
|
@ -378,7 +378,7 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
|
||||
return False
|
||||
|
||||
# In normal case, word embeding table is the largest, and segment embedding table is the smallest, while postion embedding table is in between.
|
||||
# In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
|
||||
# TODO: use other information (like initializer names) to identify different embedding weights automatically.
|
||||
if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
|
||||
logger.warning(
|
||||
|
|
@ -430,6 +430,7 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
segment_embedding_gather: Union[None, NodeProto],
|
||||
position_ids: Optional[str] = None,
|
||||
embedding_sum_output=False,
|
||||
embedding_sum_name=None,
|
||||
):
|
||||
"""Create an EmbedLayerNormalization node. Note that segment embedding is optional.
|
||||
|
||||
|
|
@ -487,7 +488,8 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
|
||||
embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
|
||||
if embedding_sum_output:
|
||||
embed_node_outputs.append(node_name + "_embedding_sum")
|
||||
name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
|
||||
embed_node_outputs.append(name)
|
||||
|
||||
embed_node = helper.make_node(
|
||||
"EmbedLayerNormalization",
|
||||
|
|
@ -522,19 +524,8 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
# use prune graph to remove nodes that is not needed
|
||||
self.prune_graph = True
|
||||
|
||||
def is_embedding_sum_needed(self, add_before_layer_norm):
|
||||
"""Check that Add before layer norm has an output to add before next layernorm
|
||||
|
||||
Args:
|
||||
add_before_layer_norm (NodeProto): Add before any LayerNormalization node in topological order of graph
|
||||
|
||||
Returns:
|
||||
bool: whether there is an extra output needed out of embed layer norm node
|
||||
"""
|
||||
|
||||
nodes = self.model.get_children(add_before_layer_norm)
|
||||
|
||||
return len(nodes) > 1
|
||||
def is_skip_layer_norm_with_sum_output(self, node):
|
||||
return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
|
||||
|
||||
def fuse_gpt2(
|
||||
self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
|
||||
|
|
@ -570,21 +561,31 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
|
||||
return False
|
||||
|
||||
# If the add_before_layernorm node is an Add node, then the add_output output is the first index
|
||||
# output of this node.
|
||||
|
||||
# If the add_before_layernorm node is SkipLayerNormalization node, then the add_output output
|
||||
# If layernorm node is SkipLayerNormalization, we need look at its optional fourth output.
|
||||
# If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node.
|
||||
# If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output
|
||||
# is the (optional) fourth index output of this node.
|
||||
add_output = None
|
||||
optional_embedding_sum_output = False
|
||||
if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or (
|
||||
add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4
|
||||
):
|
||||
optional_embedding_sum_output = True
|
||||
add_output = (
|
||||
add_before_layernorm.output[0]
|
||||
if add_before_layernorm.op_type == "Add"
|
||||
else add_before_layernorm.output[3]
|
||||
# When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node.
|
||||
if layernorm.op_type == "SkipLayerNormalization":
|
||||
need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
|
||||
sum_output_index = 3
|
||||
node_with_sum_output = layernorm
|
||||
sum_output = layernorm.output[3] if need_embedding_sum_output else None
|
||||
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
|
||||
else: # layernorm.op_type == "LayerNormalization"
|
||||
node_with_sum_output = add_before_layernorm
|
||||
sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
|
||||
sum_output = (
|
||||
add_before_layernorm.output[sum_output_index]
|
||||
if len(add_before_layernorm.output) > sum_output_index
|
||||
else None
|
||||
)
|
||||
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
|
||||
is_sum_used_by_multiple_nodes = (
|
||||
sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
|
||||
)
|
||||
need_embedding_sum_output = (sum_output is not None) and (
|
||||
add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
|
||||
)
|
||||
|
||||
# make the fused node
|
||||
|
|
@ -595,14 +596,16 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
position_embedding_gather,
|
||||
optional_segment_gather,
|
||||
position_ids,
|
||||
optional_embedding_sum_output,
|
||||
embedding_sum_output=need_embedding_sum_output,
|
||||
embedding_sum_name=sum_output if is_sum_graph_output else None,
|
||||
)
|
||||
|
||||
# direct the output to another add too
|
||||
self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
|
||||
if optional_embedding_sum_output:
|
||||
self.model.replace_input_of_all_nodes(add_output, embed_node.output[2])
|
||||
if need_embedding_sum_output:
|
||||
node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
|
||||
if not is_sum_graph_output:
|
||||
self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
|
||||
|
||||
self.finish_fusion(layernorm, embed_node)
|
||||
return True
|
||||
|
||||
def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
|
||||
|
|
@ -707,9 +710,14 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
|
||||
gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
|
||||
if gather_0_path is None and gather_1_path is not None:
|
||||
if first_add_path is None:
|
||||
return
|
||||
add_before_layernorm = first_add_path[0]
|
||||
optional_segment_gather = gather_1_path[0]
|
||||
elif gather_0_path is not None and gather_1_path is None:
|
||||
first_add_path = self.model.match_parent_path(node, ["Add"], [1])
|
||||
if first_add_path is None:
|
||||
return
|
||||
add_before_layernorm = first_add_path[0]
|
||||
optional_segment_gather = gather_0_path[0]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -555,6 +555,8 @@ def create_gpt2_embedlayer(
|
|||
num_heads=4,
|
||||
epsilon=0.1,
|
||||
one_attention_node=False,
|
||||
has_skip_layer_norm=True,
|
||||
output_embedding_sum=False,
|
||||
):
|
||||
# Construct input and output nodes
|
||||
inputs = [
|
||||
|
|
@ -564,21 +566,47 @@ def create_gpt2_embedlayer(
|
|||
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
|
||||
]
|
||||
|
||||
if output_embedding_sum:
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
"embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]
|
||||
)
|
||||
)
|
||||
|
||||
# Construct graph nodes
|
||||
embed_layernorm_nodes = [
|
||||
helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"),
|
||||
helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"),
|
||||
helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"),
|
||||
helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"),
|
||||
helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
|
||||
["skip_layernorm_out"],
|
||||
"skip_layernorm",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
),
|
||||
]
|
||||
|
||||
if has_skip_layer_norm:
|
||||
embed_layernorm_nodes.append(
|
||||
helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
|
||||
["skip_layernorm_out"] if not output_embedding_sum else ["skip_layernorm_out", "", "", "embedding_sum"],
|
||||
"skip_layernorm",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed_layernorm_nodes.append(
|
||||
helper.make_node("Add", ["add_0_out", "gather_2_out"], ["embedding_sum"], "embedding_sum")
|
||||
)
|
||||
|
||||
embed_layernorm_nodes.append(
|
||||
helper.make_node(
|
||||
"LayerNormalization",
|
||||
["embedding_sum", "layernorm_weight", "layernorm_bias"],
|
||||
["skip_layernorm_out"],
|
||||
"layernorm",
|
||||
epsilon=epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
attention_nodes = (
|
||||
[
|
||||
helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"),
|
||||
|
|
@ -708,6 +736,7 @@ def create_gpt2_fused_embedlayer(
|
|||
num_heads=4,
|
||||
epsilon=0.1,
|
||||
one_attention_node=False,
|
||||
output_embedding_sum=False,
|
||||
):
|
||||
# Construct input and output nodes
|
||||
inputs = [
|
||||
|
|
@ -716,6 +745,12 @@ def create_gpt2_fused_embedlayer(
|
|||
outputs = [
|
||||
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
|
||||
]
|
||||
if output_embedding_sum:
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
"embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]
|
||||
)
|
||||
)
|
||||
|
||||
# Construct graph nodes
|
||||
embed_layernorm_nodes = [
|
||||
|
|
@ -732,7 +767,9 @@ def create_gpt2_fused_embedlayer(
|
|||
"",
|
||||
"ids",
|
||||
],
|
||||
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
|
||||
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"]
|
||||
if output_embedding_sum
|
||||
else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
|
||||
"EmbedLayerNormalization_0",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
|
|
@ -876,3 +913,9 @@ if __name__ == "__main__":
|
|||
|
||||
model = create_gpt2_fused_embedlayer(one_attention_node=True)
|
||||
onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_exp.onnx")
|
||||
|
||||
model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True)
|
||||
onnx.save(model, "gpt2_embedlayer_one_attn_output_sum.onnx")
|
||||
|
||||
model = create_gpt2_fused_embedlayer(one_attention_node=True, output_embedding_sum=True)
|
||||
onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx")
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -74,6 +74,38 @@ class TestFusion(unittest.TestCase):
|
|||
os.remove(original_model_path)
|
||||
os.remove(optimized_model_path)
|
||||
|
||||
def test_embedlayer_fusion_with_embedding_sum_output(self):
|
||||
model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True)
|
||||
path = "."
|
||||
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum.onnx")
|
||||
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_opt.onnx")
|
||||
expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx"
|
||||
|
||||
onnx.save(model, original_model_path)
|
||||
optimized_model = optimize_model(original_model_path, model_type="gpt2")
|
||||
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
|
||||
|
||||
self.verify_fusion(optimized_model, expected_model_filename)
|
||||
self.verify_parity(optimized_model_path, expected_model_filename)
|
||||
os.remove(original_model_path)
|
||||
os.remove(optimized_model_path)
|
||||
|
||||
def test_embedlayer_fusion_with_embedding_sum_output_no_sln(self):
|
||||
model = create_gpt2_embedlayer(one_attention_node=True, has_skip_layer_norm=False, output_embedding_sum=True)
|
||||
path = "."
|
||||
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln.onnx")
|
||||
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln_opt.onnx")
|
||||
expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx"
|
||||
|
||||
onnx.save(model, original_model_path)
|
||||
optimized_model = optimize_model(original_model_path, model_type="gpt2")
|
||||
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
|
||||
|
||||
self.verify_fusion(optimized_model, expected_model_filename)
|
||||
self.verify_parity(optimized_model_path, expected_model_filename)
|
||||
os.remove(original_model_path)
|
||||
os.remove(optimized_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue