fix embed layer norm fusion with embedding sum output (#17460)

The embedding sum could be graph output (when exporting with output
hidden state enabled). Previously, we only check whether there are
multiple children node to decide whether to output embedding sum in
fused node. This fix will check if the sum is graph output, we will
retain the name.
This commit is contained in:
Tianlei Wu 2023-09-07 22:01:26 -07:00 committed by GitHub
parent 9017ea131b
commit 7bc6dcecf7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 126 additions and 43 deletions

View file

@ -378,7 +378,7 @@ class FusionEmbedLayerNoMask(Fusion):
logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
return False
# In normal case, word embeding table is the largest, and segment embedding table is the smallest, while postion embedding table is in between.
# In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
# TODO: use other information (like initializer names) to identify different embedding weights automatically.
if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
logger.warning(
@ -430,6 +430,7 @@ class FusionEmbedLayerNoMask(Fusion):
segment_embedding_gather: Union[None, NodeProto],
position_ids: Optional[str] = None,
embedding_sum_output=False,
embedding_sum_name=None,
):
"""Create an EmbedLayerNormalization node. Note that segment embedding is optional.
@ -487,7 +488,8 @@ class FusionEmbedLayerNoMask(Fusion):
embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
if embedding_sum_output:
embed_node_outputs.append(node_name + "_embedding_sum")
name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
embed_node_outputs.append(name)
embed_node = helper.make_node(
"EmbedLayerNormalization",
@ -522,19 +524,8 @@ class FusionEmbedLayerNoMask(Fusion):
# use prune graph to remove nodes that is not needed
self.prune_graph = True
def is_embedding_sum_needed(self, add_before_layer_norm):
"""Check that Add before layer norm has an output to add before next layernorm
Args:
add_before_layer_norm (NodeProto): Add before any LayerNormalization node in topological order of graph
Returns:
bool: whether there is an extra output needed out of embed layer norm node
"""
nodes = self.model.get_children(add_before_layer_norm)
return len(nodes) > 1
def is_skip_layer_norm_with_sum_output(self, node):
return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
def fuse_gpt2(
self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
@ -570,21 +561,31 @@ class FusionEmbedLayerNoMask(Fusion):
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
return False
# If the add_before_layernorm node is an Add node, then the add_output output is the first index
# output of this node.
# If the add_before_layernorm node is SkipLayerNormalization node, then the add_output output
# If layernorm node is SkipLayerNormalization, we need look at its optional fourth output.
# If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node.
# If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output
# is the (optional) fourth index output of this node.
add_output = None
optional_embedding_sum_output = False
if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or (
add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4
):
optional_embedding_sum_output = True
add_output = (
add_before_layernorm.output[0]
if add_before_layernorm.op_type == "Add"
else add_before_layernorm.output[3]
# When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node.
if layernorm.op_type == "SkipLayerNormalization":
need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
sum_output_index = 3
node_with_sum_output = layernorm
sum_output = layernorm.output[3] if need_embedding_sum_output else None
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
else: # layernorm.op_type == "LayerNormalization"
node_with_sum_output = add_before_layernorm
sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
sum_output = (
add_before_layernorm.output[sum_output_index]
if len(add_before_layernorm.output) > sum_output_index
else None
)
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
is_sum_used_by_multiple_nodes = (
sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
)
need_embedding_sum_output = (sum_output is not None) and (
add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
)
# make the fused node
@ -595,14 +596,16 @@ class FusionEmbedLayerNoMask(Fusion):
position_embedding_gather,
optional_segment_gather,
position_ids,
optional_embedding_sum_output,
embedding_sum_output=need_embedding_sum_output,
embedding_sum_name=sum_output if is_sum_graph_output else None,
)
# direct the output to another add too
self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
if optional_embedding_sum_output:
self.model.replace_input_of_all_nodes(add_output, embed_node.output[2])
if need_embedding_sum_output:
node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
if not is_sum_graph_output:
self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
self.finish_fusion(layernorm, embed_node)
return True
def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
@ -707,9 +710,14 @@ class FusionEmbedLayerNoMask(Fusion):
gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
if gather_0_path is None and gather_1_path is not None:
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_1_path[0]
elif gather_0_path is not None and gather_1_path is None:
first_add_path = self.model.match_parent_path(node, ["Add"], [1])
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_0_path[0]
else:

View file

@ -555,6 +555,8 @@ def create_gpt2_embedlayer(
num_heads=4,
epsilon=0.1,
one_attention_node=False,
has_skip_layer_norm=True,
output_embedding_sum=False,
):
# Construct input and output nodes
inputs = [
@ -564,21 +566,47 @@ def create_gpt2_embedlayer(
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
]
if output_embedding_sum:
outputs.append(
helper.make_tensor_value_info(
"embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]
)
)
# Construct graph nodes
embed_layernorm_nodes = [
helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"),
helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"),
helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"),
helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"),
helper.make_node(
"SkipLayerNormalization",
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
["skip_layernorm_out"],
"skip_layernorm",
domain="com.microsoft",
epsilon=epsilon,
),
]
if has_skip_layer_norm:
embed_layernorm_nodes.append(
helper.make_node(
"SkipLayerNormalization",
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
["skip_layernorm_out"] if not output_embedding_sum else ["skip_layernorm_out", "", "", "embedding_sum"],
"skip_layernorm",
domain="com.microsoft",
epsilon=epsilon,
)
)
else:
embed_layernorm_nodes.append(
helper.make_node("Add", ["add_0_out", "gather_2_out"], ["embedding_sum"], "embedding_sum")
)
embed_layernorm_nodes.append(
helper.make_node(
"LayerNormalization",
["embedding_sum", "layernorm_weight", "layernorm_bias"],
["skip_layernorm_out"],
"layernorm",
epsilon=epsilon,
)
)
attention_nodes = (
[
helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"),
@ -708,6 +736,7 @@ def create_gpt2_fused_embedlayer(
num_heads=4,
epsilon=0.1,
one_attention_node=False,
output_embedding_sum=False,
):
# Construct input and output nodes
inputs = [
@ -716,6 +745,12 @@ def create_gpt2_fused_embedlayer(
outputs = [
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
]
if output_embedding_sum:
outputs.append(
helper.make_tensor_value_info(
"embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]
)
)
# Construct graph nodes
embed_layernorm_nodes = [
@ -732,7 +767,9 @@ def create_gpt2_fused_embedlayer(
"",
"ids",
],
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"]
if output_embedding_sum
else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
"EmbedLayerNormalization_0",
domain="com.microsoft",
epsilon=epsilon,
@ -876,3 +913,9 @@ if __name__ == "__main__":
model = create_gpt2_fused_embedlayer(one_attention_node=True)
onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_exp.onnx")
model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True)
onnx.save(model, "gpt2_embedlayer_one_attn_output_sum.onnx")
model = create_gpt2_fused_embedlayer(one_attention_node=True, output_embedding_sum=True)
onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx")

View file

@ -74,6 +74,38 @@ class TestFusion(unittest.TestCase):
os.remove(original_model_path)
os.remove(optimized_model_path)
def test_embedlayer_fusion_with_embedding_sum_output(self):
model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True)
path = "."
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum.onnx")
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_opt.onnx")
expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx"
onnx.save(model, original_model_path)
optimized_model = optimize_model(original_model_path, model_type="gpt2")
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
self.verify_fusion(optimized_model, expected_model_filename)
self.verify_parity(optimized_model_path, expected_model_filename)
os.remove(original_model_path)
os.remove(optimized_model_path)
def test_embedlayer_fusion_with_embedding_sum_output_no_sln(self):
model = create_gpt2_embedlayer(one_attention_node=True, has_skip_layer_norm=False, output_embedding_sum=True)
path = "."
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln.onnx")
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln_opt.onnx")
expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx"
onnx.save(model, original_model_path)
optimized_model = optimize_model(original_model_path, model_type="gpt2")
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
self.verify_fusion(optimized_model, expected_model_filename)
self.verify_parity(optimized_model_path, expected_model_filename)
os.remove(original_model_path)
os.remove(optimized_model_path)
if __name__ == "__main__":
unittest.main()