Update pattern matching for EmbedLayerNormalization fusion (#14344)

### Description
This PR addresses the case where an optional Gather node is in the
subgraph pattern. The optional node is now fused with the other nodes
matched in the pattern to create an EmbedLayerNormalization node.



### Motivation and Context
The original subgraph pattern is
```
                      Gather    Gather
                           \   /
                            Add
                             |           
                     LayerNormalization
                             |           
                          Attention
                             |  
                            ...
```
and the new subgraph pattern is
```
                      Gather    Gather
                           \   /
   Gather (optional)        Add
                   \         |           
                     LayerNormalization
                             |           
                          Attention
                             |  
                            ...
```
This commit is contained in:
kunal-vaishnavi 2023-02-22 12:57:14 -08:00 committed by GitHub
parent b3b9be19b1
commit 460b3ff4fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 439 additions and 21 deletions

View file

@ -531,24 +531,26 @@ class FusionEmbedLayerNoMask(Fusion):
return len(nodes) > 1
def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
def fuse_gpt2(
self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
):
# graph checks
# gpt2 has no segment embedding, subgraph pattern is like
# input_ids position_ids
# | |
# Gather Gather
# \ /
# Add _ _ _ _ _
# | |
# LayerNormalization |
# | |
# Attention |
# | |
# Matmul |
# | /
# Add /
# \ /
# Add
# gpt2 has optional segment embedding, subgraph pattern is like
# input_ids position_ids
# | |
# token_ids Gather Gather
# | \ /
# Gather (optional) Add _ _ _ _ _
# \ | |
# LayerNormalization |
# | |
# Attention |
# | |
# Matmul |
# | /
# Add /
# \ /
# Add
two_gather = self.match_two_gather(add_before_layernorm)
if two_gather is None:
return False
@ -586,7 +588,7 @@ class FusionEmbedLayerNoMask(Fusion):
layernorm,
word_embedding_gather,
position_embedding_gather,
None,
optional_segment_gather,
position_ids,
optional_embedding_sum_output,
)
@ -690,15 +692,28 @@ class FusionEmbedLayerNoMask(Fusion):
return True
def fuse(self, node, input_name_to_nodes, output_name_to_node):
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
if node.op_type == "LayerNormalization":
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = None
else: # SkipLayerNormalization
add_before_layernorm = node # Add is fused into SkipLayerNormalization
gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
if gather_0_path is None and gather_1_path is not None:
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_1_path[0]
elif gather_0_path is not None and gather_1_path is None:
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_0_path[0]
else:
add_before_layernorm = node # Add is fused into SkipLayerNormalization
optional_segment_gather = None
if self.fuse_gpt2(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
if self.fuse_gpt2(
node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
):
return
if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):

View file

@ -549,9 +549,333 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
return helper.make_model(graph, opset_imports=(opsetid,))
def create_gpt2_embedlayer(
pos_embed=100,
word_embed=101,
token_embed=10,
hidden_size=768,
attn_hidden_dim=256,
num_heads=4,
epsilon=0.1,
one_attention_node=False,
):
# Construct input and output nodes
inputs = [
helper.make_tensor_value_info("ids", TensorProto.INT32, ["batch_size", "sequence_length"]),
]
outputs = [
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
]
# Construct graph nodes
embed_layernorm_nodes = [
helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"),
helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"),
helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"),
helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"),
helper.make_node(
"SkipLayerNormalization",
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
["skip_layernorm_out"],
"skip_layernorm",
domain="com.microsoft",
epsilon=epsilon,
),
]
attention_nodes = (
[
helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"),
helper.make_node("MatMul", ["skip_layernorm_out", "k_weight"], ["k_out"], "k_attn"),
helper.make_node("MatMul", ["skip_layernorm_out", "v_weight"], ["v_out"], "v_attn"),
helper.make_node("Add", ["q_out", "k_out"], ["qk_out"], "qk_attn"),
helper.make_node("Add", ["qk_out", "v_out"], ["qkv_out"], "qkv_attn"),
]
if not one_attention_node
else [
helper.make_node(
"Attention",
["skip_layernorm_out", "qkv_weight", "qkv_bias", ""],
["attn_out"],
"qkv_attn",
domain="com.microsoft",
num_heads=num_heads,
),
helper.make_node(
"MatMul",
["attn_out", "fix_hidden_size"],
["qkv_out"],
"matmul_after_attn",
),
]
)
nodes = [
helper.make_node(
"SkipLayerNormalization",
["skip_layernorm_out", "qkv_out", "layernorm_weight", "layernorm_bias", "dense_bias"],
["output_0"],
"attn_skip_layernorm",
domain="com.microsoft",
epsilon=epsilon,
),
]
nodes.extend(embed_layernorm_nodes)
nodes.extend(attention_nodes)
# Construct data initializers for graph nodes
embed_layernorm_initializers = [
helper.make_tensor(
"word_embeddings_weight",
TensorProto.FLOAT,
[word_embed, hidden_size],
[(i + 1) / (word_embed * hidden_size) for i in range(word_embed * hidden_size)],
),
helper.make_tensor(
"pos_embeddings_weight",
TensorProto.FLOAT,
[pos_embed, hidden_size],
[(i + 2) / (pos_embed * hidden_size) for i in range(pos_embed * hidden_size)],
),
helper.make_tensor(
"token_embeddings_weight",
TensorProto.FLOAT,
[token_embed, hidden_size],
[(i + 3) / (token_embed * hidden_size) for i in range(token_embed * hidden_size)],
),
helper.make_tensor(
"layernorm_weight", TensorProto.FLOAT, [hidden_size], [(i + 4) / hidden_size for i in range(hidden_size)]
),
helper.make_tensor(
"layernorm_bias", TensorProto.FLOAT, [hidden_size], [(i + 5) / hidden_size for i in range(hidden_size)]
),
]
attention_initializers = (
[
helper.make_tensor(
"q_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 6) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"k_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 7) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"v_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 8) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
]
if not one_attention_node
else [
helper.make_tensor(
"qkv_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 9) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"qkv_bias", TensorProto.FLOAT, [hidden_size], [(i + 10) / hidden_size for i in range(hidden_size)]
),
helper.make_tensor(
"fix_hidden_size",
TensorProto.FLOAT,
[attn_hidden_dim, hidden_size],
[(i + 11) / (attn_hidden_dim * hidden_size) for i in range(attn_hidden_dim * hidden_size)],
),
]
)
initializers = [
helper.make_tensor(
"dense_bias", TensorProto.FLOAT, [hidden_size], [(i + 12) / hidden_size for i in range(hidden_size)]
),
]
initializers.extend(embed_layernorm_initializers)
initializers.extend(attention_initializers)
# Construct graph
graph = helper.make_graph(nodes, "GPT2_embedlayer_graph", inputs, outputs, initializers)
opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
return helper.make_model(graph, opset_imports=(opsetid,))
def create_gpt2_fused_embedlayer(
pos_embed=100,
word_embed=101,
token_embed=10,
hidden_size=768,
attn_hidden_dim=256,
num_heads=4,
epsilon=0.1,
one_attention_node=False,
):
# Construct input and output nodes
inputs = [
helper.make_tensor_value_info("ids", TensorProto.INT32, ["batch_size", "sequence_length"]),
]
outputs = [
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
]
# Construct graph nodes
embed_layernorm_nodes = [
helper.make_node(
"EmbedLayerNormalization",
[
"ids",
"ids",
"word_embeddings_weight",
"pos_embeddings_weight",
"token_embeddings_weight",
"layernorm_weight",
"layernorm_bias",
"",
"ids",
],
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
"EmbedLayerNormalization_0",
domain="com.microsoft",
epsilon=epsilon,
),
]
attention_nodes = (
[
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "q_weight"], ["q_out"], "q_attn"),
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "k_weight"], ["k_out"], "k_attn"),
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "v_weight"], ["v_out"], "v_attn"),
helper.make_node("Add", ["q_out", "k_out"], ["qk_out"], "qk_attn"),
helper.make_node("Add", ["qk_out", "v_out"], ["qkv_out"], "qkv_attn"),
]
if not one_attention_node
else [
helper.make_node(
"Attention",
["EmbedLayerNormalization_0_output", "qkv_weight", "qkv_bias", ""],
["attn_out"],
"qkv_attn",
domain="com.microsoft",
num_heads=num_heads,
),
helper.make_node(
"MatMul",
["attn_out", "fix_hidden_size"],
["qkv_out"],
"matmul_after_attn",
),
]
)
nodes = [
helper.make_node(
"SkipLayerNormalization",
["EmbedLayerNormalization_0_output", "qkv_out", "layernorm_weight", "layernorm_bias", "dense_bias"],
["output_0"],
"attn_skip_layernorm",
domain="com.microsoft",
epsilon=epsilon,
),
]
nodes.extend(embed_layernorm_nodes)
nodes.extend(attention_nodes)
# Construct data initializers for graph nodes
embed_layernorm_initializers = [
helper.make_tensor(
"word_embeddings_weight",
TensorProto.FLOAT,
[word_embed, hidden_size],
[(i + 1) / (word_embed * hidden_size) for i in range(word_embed * hidden_size)],
),
helper.make_tensor(
"pos_embeddings_weight",
TensorProto.FLOAT,
[pos_embed, hidden_size],
[(i + 2) / (pos_embed * hidden_size) for i in range(pos_embed * hidden_size)],
),
helper.make_tensor(
"token_embeddings_weight",
TensorProto.FLOAT,
[token_embed, hidden_size],
[(i + 3) / (token_embed * hidden_size) for i in range(token_embed * hidden_size)],
),
helper.make_tensor(
"layernorm_weight", TensorProto.FLOAT, [hidden_size], [(i + 4) / hidden_size for i in range(hidden_size)]
),
helper.make_tensor(
"layernorm_bias", TensorProto.FLOAT, [hidden_size], [(i + 5) / hidden_size for i in range(hidden_size)]
),
]
attention_initializers = (
[
helper.make_tensor(
"q_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 6) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"k_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 7) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"v_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 8) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
]
if not one_attention_node
else [
helper.make_tensor(
"qkv_weight",
TensorProto.FLOAT,
[hidden_size, hidden_size],
[(i + 9) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
),
helper.make_tensor(
"qkv_bias", TensorProto.FLOAT, [hidden_size], [(i + 10) / hidden_size for i in range(hidden_size)]
),
helper.make_tensor(
"fix_hidden_size",
TensorProto.FLOAT,
[attn_hidden_dim, hidden_size],
[(i + 11) / (attn_hidden_dim * hidden_size) for i in range(attn_hidden_dim * hidden_size)],
),
]
)
initializers = [
helper.make_tensor(
"dense_bias", TensorProto.FLOAT, [hidden_size], [(i + 12) / hidden_size for i in range(hidden_size)]
),
]
initializers.extend(embed_layernorm_initializers)
initializers.extend(attention_initializers)
# Construct graph
graph = helper.make_graph(nodes, "GPT2_embedlayer_graph", inputs, outputs, initializers)
opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
return helper.make_model(graph, opset_imports=(opsetid,))
if __name__ == "__main__":
model = create_gpt2_attention()
onnx.save(model, "gpt2_attention.onnx")
model = create_gpt2_attention(switch_add_inputs=True)
onnx.save(model, "gpt2_attention_add.onnx")
model = create_gpt2_embedlayer()
onnx.save(model, "gpt2_embedlayer.onnx")
model = create_gpt2_fused_embedlayer()
onnx.save(model, "gpt2_embedlayer_exp.onnx")
model = create_gpt2_embedlayer(one_attention_node=True)
onnx.save(model, "gpt2_embedlayer_one_attn.onnx")
model = create_gpt2_fused_embedlayer(one_attention_node=True)
onnx.save(model, "gpt2_embedlayer_one_attn_exp.onnx")

View file

@ -0,0 +1,79 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import unittest
import numpy as np
import onnx
from gpt2_model_generator import create_gpt2_embedlayer
from parity_utilities import find_transformers_source
from onnxruntime import InferenceSession
if find_transformers_source():
from onnx_model import OnnxModel
from optimizer import optimize_model
else:
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.optimizer import optimize_model
class TestFusion(unittest.TestCase):
def verify_fusion(self, optimized_model, expected_model_filename):
optimized_model.topological_sort()
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
expected_model = OnnxModel(onnx.load(expected_model_path))
expected_model.topological_sort()
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
def verify_parity(self, optimized_model_path, expected_model):
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model)
sess_optimized = InferenceSession(optimized_model_path, providers=["CPUExecutionProvider"])
sess_expected = InferenceSession(expected_model_path, providers=["CPUExecutionProvider"])
inputs = np.random.randint(low=0, high=6, size=(4, 8), dtype=np.int32) + 1
outputs_optimized = sess_optimized.run(None, {"ids": inputs})
outputs_expected = sess_expected.run(None, {"ids": inputs})
self.assertTrue(np.allclose(outputs_optimized[0], outputs_expected[0]))
def test_embedlayer_fusion(self):
model = create_gpt2_embedlayer(one_attention_node=False)
path = "."
original_model_path = os.path.join(path, "gpt2_embedlayer.onnx")
optimized_model_path = os.path.join(path, "gpt2_embedlayer_opt.onnx")
expected_model_filename = "gpt2_embedlayer_exp.onnx"
onnx.save(model, original_model_path)
optimized_model = optimize_model(original_model_path, model_type="gpt2")
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
self.verify_fusion(optimized_model, expected_model_filename)
self.verify_parity(optimized_model_path, expected_model_filename)
os.remove(original_model_path)
os.remove(optimized_model_path)
def test_embedlayer_fusion_one_attn_node(self):
model = create_gpt2_embedlayer(one_attention_node=True)
path = "."
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn.onnx")
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_opt.onnx")
expected_model_filename = "gpt2_embedlayer_one_attn_exp.onnx"
onnx.save(model, original_model_path)
optimized_model = optimize_model(original_model_path, model_type="gpt2")
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
self.verify_fusion(optimized_model, expected_model_filename)
self.verify_parity(optimized_model_path, expected_model_filename)
os.remove(original_model_path)
os.remove(optimized_model_path)
if __name__ == "__main__":
unittest.main()