mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Update pattern matching for EmbedLayerNormalization fusion (#14344)
### Description
This PR addresses the case where an optional Gather node is in the
subgraph pattern. The optional node is now fused with the other nodes
matched in the pattern to create an EmbedLayerNormalization node.
### Motivation and Context
The original subgraph pattern is
```
Gather Gather
\ /
Add
|
LayerNormalization
|
Attention
|
...
```
and the new subgraph pattern is
```
Gather Gather
\ /
Gather (optional) Add
\ |
LayerNormalization
|
Attention
|
...
```
This commit is contained in:
parent
b3b9be19b1
commit
460b3ff4fd
5 changed files with 439 additions and 21 deletions
|
|
@ -531,24 +531,26 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
|
||||
return len(nodes) > 1
|
||||
|
||||
def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
|
||||
def fuse_gpt2(
|
||||
self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
|
||||
):
|
||||
# graph checks
|
||||
# gpt2 has no segment embedding, subgraph pattern is like
|
||||
# input_ids position_ids
|
||||
# | |
|
||||
# Gather Gather
|
||||
# \ /
|
||||
# Add _ _ _ _ _
|
||||
# | |
|
||||
# LayerNormalization |
|
||||
# | |
|
||||
# Attention |
|
||||
# | |
|
||||
# Matmul |
|
||||
# | /
|
||||
# Add /
|
||||
# \ /
|
||||
# Add
|
||||
# gpt2 has optional segment embedding, subgraph pattern is like
|
||||
# input_ids position_ids
|
||||
# | |
|
||||
# token_ids Gather Gather
|
||||
# | \ /
|
||||
# Gather (optional) Add _ _ _ _ _
|
||||
# \ | |
|
||||
# LayerNormalization |
|
||||
# | |
|
||||
# Attention |
|
||||
# | |
|
||||
# Matmul |
|
||||
# | /
|
||||
# Add /
|
||||
# \ /
|
||||
# Add
|
||||
two_gather = self.match_two_gather(add_before_layernorm)
|
||||
if two_gather is None:
|
||||
return False
|
||||
|
|
@ -586,7 +588,7 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
layernorm,
|
||||
word_embedding_gather,
|
||||
position_embedding_gather,
|
||||
None,
|
||||
optional_segment_gather,
|
||||
position_ids,
|
||||
optional_embedding_sum_output,
|
||||
)
|
||||
|
|
@ -690,15 +692,28 @@ class FusionEmbedLayerNoMask(Fusion):
|
|||
return True
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
|
||||
if node.op_type == "LayerNormalization":
|
||||
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
|
||||
if first_add_path is None:
|
||||
return
|
||||
add_before_layernorm = first_add_path[0]
|
||||
optional_segment_gather = None
|
||||
else: # SkipLayerNormalization
|
||||
add_before_layernorm = node # Add is fused into SkipLayerNormalization
|
||||
gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
|
||||
gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
|
||||
if gather_0_path is None and gather_1_path is not None:
|
||||
add_before_layernorm = first_add_path[0]
|
||||
optional_segment_gather = gather_1_path[0]
|
||||
elif gather_0_path is not None and gather_1_path is None:
|
||||
add_before_layernorm = first_add_path[0]
|
||||
optional_segment_gather = gather_0_path[0]
|
||||
else:
|
||||
add_before_layernorm = node # Add is fused into SkipLayerNormalization
|
||||
optional_segment_gather = None
|
||||
|
||||
if self.fuse_gpt2(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
|
||||
if self.fuse_gpt2(
|
||||
node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
|
||||
):
|
||||
return
|
||||
|
||||
if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
|
||||
|
|
|
|||
|
|
@ -549,9 +549,333 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
|
|||
return helper.make_model(graph, opset_imports=(opsetid,))
|
||||
|
||||
|
||||
def create_gpt2_embedlayer(
|
||||
pos_embed=100,
|
||||
word_embed=101,
|
||||
token_embed=10,
|
||||
hidden_size=768,
|
||||
attn_hidden_dim=256,
|
||||
num_heads=4,
|
||||
epsilon=0.1,
|
||||
one_attention_node=False,
|
||||
):
|
||||
# Construct input and output nodes
|
||||
inputs = [
|
||||
helper.make_tensor_value_info("ids", TensorProto.INT32, ["batch_size", "sequence_length"]),
|
||||
]
|
||||
outputs = [
|
||||
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
|
||||
]
|
||||
|
||||
# Construct graph nodes
|
||||
embed_layernorm_nodes = [
|
||||
helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"),
|
||||
helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"),
|
||||
helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"),
|
||||
helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"),
|
||||
helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"],
|
||||
["skip_layernorm_out"],
|
||||
"skip_layernorm",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
),
|
||||
]
|
||||
attention_nodes = (
|
||||
[
|
||||
helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"),
|
||||
helper.make_node("MatMul", ["skip_layernorm_out", "k_weight"], ["k_out"], "k_attn"),
|
||||
helper.make_node("MatMul", ["skip_layernorm_out", "v_weight"], ["v_out"], "v_attn"),
|
||||
helper.make_node("Add", ["q_out", "k_out"], ["qk_out"], "qk_attn"),
|
||||
helper.make_node("Add", ["qk_out", "v_out"], ["qkv_out"], "qkv_attn"),
|
||||
]
|
||||
if not one_attention_node
|
||||
else [
|
||||
helper.make_node(
|
||||
"Attention",
|
||||
["skip_layernorm_out", "qkv_weight", "qkv_bias", ""],
|
||||
["attn_out"],
|
||||
"qkv_attn",
|
||||
domain="com.microsoft",
|
||||
num_heads=num_heads,
|
||||
),
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
["attn_out", "fix_hidden_size"],
|
||||
["qkv_out"],
|
||||
"matmul_after_attn",
|
||||
),
|
||||
]
|
||||
)
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
["skip_layernorm_out", "qkv_out", "layernorm_weight", "layernorm_bias", "dense_bias"],
|
||||
["output_0"],
|
||||
"attn_skip_layernorm",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
),
|
||||
]
|
||||
nodes.extend(embed_layernorm_nodes)
|
||||
nodes.extend(attention_nodes)
|
||||
|
||||
# Construct data initializers for graph nodes
|
||||
embed_layernorm_initializers = [
|
||||
helper.make_tensor(
|
||||
"word_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[word_embed, hidden_size],
|
||||
[(i + 1) / (word_embed * hidden_size) for i in range(word_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"pos_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[pos_embed, hidden_size],
|
||||
[(i + 2) / (pos_embed * hidden_size) for i in range(pos_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"token_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[token_embed, hidden_size],
|
||||
[(i + 3) / (token_embed * hidden_size) for i in range(token_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"layernorm_weight", TensorProto.FLOAT, [hidden_size], [(i + 4) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
helper.make_tensor(
|
||||
"layernorm_bias", TensorProto.FLOAT, [hidden_size], [(i + 5) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
]
|
||||
attention_initializers = (
|
||||
[
|
||||
helper.make_tensor(
|
||||
"q_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 6) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"k_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 7) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"v_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 8) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
]
|
||||
if not one_attention_node
|
||||
else [
|
||||
helper.make_tensor(
|
||||
"qkv_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 9) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"qkv_bias", TensorProto.FLOAT, [hidden_size], [(i + 10) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
helper.make_tensor(
|
||||
"fix_hidden_size",
|
||||
TensorProto.FLOAT,
|
||||
[attn_hidden_dim, hidden_size],
|
||||
[(i + 11) / (attn_hidden_dim * hidden_size) for i in range(attn_hidden_dim * hidden_size)],
|
||||
),
|
||||
]
|
||||
)
|
||||
initializers = [
|
||||
helper.make_tensor(
|
||||
"dense_bias", TensorProto.FLOAT, [hidden_size], [(i + 12) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
]
|
||||
initializers.extend(embed_layernorm_initializers)
|
||||
initializers.extend(attention_initializers)
|
||||
|
||||
# Construct graph
|
||||
graph = helper.make_graph(nodes, "GPT2_embedlayer_graph", inputs, outputs, initializers)
|
||||
opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
|
||||
return helper.make_model(graph, opset_imports=(opsetid,))
|
||||
|
||||
|
||||
def create_gpt2_fused_embedlayer(
|
||||
pos_embed=100,
|
||||
word_embed=101,
|
||||
token_embed=10,
|
||||
hidden_size=768,
|
||||
attn_hidden_dim=256,
|
||||
num_heads=4,
|
||||
epsilon=0.1,
|
||||
one_attention_node=False,
|
||||
):
|
||||
# Construct input and output nodes
|
||||
inputs = [
|
||||
helper.make_tensor_value_info("ids", TensorProto.INT32, ["batch_size", "sequence_length"]),
|
||||
]
|
||||
outputs = [
|
||||
helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size])
|
||||
]
|
||||
|
||||
# Construct graph nodes
|
||||
embed_layernorm_nodes = [
|
||||
helper.make_node(
|
||||
"EmbedLayerNormalization",
|
||||
[
|
||||
"ids",
|
||||
"ids",
|
||||
"word_embeddings_weight",
|
||||
"pos_embeddings_weight",
|
||||
"token_embeddings_weight",
|
||||
"layernorm_weight",
|
||||
"layernorm_bias",
|
||||
"",
|
||||
"ids",
|
||||
],
|
||||
["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
|
||||
"EmbedLayerNormalization_0",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
),
|
||||
]
|
||||
attention_nodes = (
|
||||
[
|
||||
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "q_weight"], ["q_out"], "q_attn"),
|
||||
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "k_weight"], ["k_out"], "k_attn"),
|
||||
helper.make_node("MatMul", ["EmbedLayerNormalization_0_output", "v_weight"], ["v_out"], "v_attn"),
|
||||
helper.make_node("Add", ["q_out", "k_out"], ["qk_out"], "qk_attn"),
|
||||
helper.make_node("Add", ["qk_out", "v_out"], ["qkv_out"], "qkv_attn"),
|
||||
]
|
||||
if not one_attention_node
|
||||
else [
|
||||
helper.make_node(
|
||||
"Attention",
|
||||
["EmbedLayerNormalization_0_output", "qkv_weight", "qkv_bias", ""],
|
||||
["attn_out"],
|
||||
"qkv_attn",
|
||||
domain="com.microsoft",
|
||||
num_heads=num_heads,
|
||||
),
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
["attn_out", "fix_hidden_size"],
|
||||
["qkv_out"],
|
||||
"matmul_after_attn",
|
||||
),
|
||||
]
|
||||
)
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"SkipLayerNormalization",
|
||||
["EmbedLayerNormalization_0_output", "qkv_out", "layernorm_weight", "layernorm_bias", "dense_bias"],
|
||||
["output_0"],
|
||||
"attn_skip_layernorm",
|
||||
domain="com.microsoft",
|
||||
epsilon=epsilon,
|
||||
),
|
||||
]
|
||||
nodes.extend(embed_layernorm_nodes)
|
||||
nodes.extend(attention_nodes)
|
||||
|
||||
# Construct data initializers for graph nodes
|
||||
embed_layernorm_initializers = [
|
||||
helper.make_tensor(
|
||||
"word_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[word_embed, hidden_size],
|
||||
[(i + 1) / (word_embed * hidden_size) for i in range(word_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"pos_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[pos_embed, hidden_size],
|
||||
[(i + 2) / (pos_embed * hidden_size) for i in range(pos_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"token_embeddings_weight",
|
||||
TensorProto.FLOAT,
|
||||
[token_embed, hidden_size],
|
||||
[(i + 3) / (token_embed * hidden_size) for i in range(token_embed * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"layernorm_weight", TensorProto.FLOAT, [hidden_size], [(i + 4) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
helper.make_tensor(
|
||||
"layernorm_bias", TensorProto.FLOAT, [hidden_size], [(i + 5) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
]
|
||||
attention_initializers = (
|
||||
[
|
||||
helper.make_tensor(
|
||||
"q_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 6) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"k_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 7) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"v_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 8) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
]
|
||||
if not one_attention_node
|
||||
else [
|
||||
helper.make_tensor(
|
||||
"qkv_weight",
|
||||
TensorProto.FLOAT,
|
||||
[hidden_size, hidden_size],
|
||||
[(i + 9) / (hidden_size * hidden_size) for i in range(hidden_size * hidden_size)],
|
||||
),
|
||||
helper.make_tensor(
|
||||
"qkv_bias", TensorProto.FLOAT, [hidden_size], [(i + 10) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
helper.make_tensor(
|
||||
"fix_hidden_size",
|
||||
TensorProto.FLOAT,
|
||||
[attn_hidden_dim, hidden_size],
|
||||
[(i + 11) / (attn_hidden_dim * hidden_size) for i in range(attn_hidden_dim * hidden_size)],
|
||||
),
|
||||
]
|
||||
)
|
||||
initializers = [
|
||||
helper.make_tensor(
|
||||
"dense_bias", TensorProto.FLOAT, [hidden_size], [(i + 12) / hidden_size for i in range(hidden_size)]
|
||||
),
|
||||
]
|
||||
initializers.extend(embed_layernorm_initializers)
|
||||
initializers.extend(attention_initializers)
|
||||
|
||||
# Construct graph
|
||||
graph = helper.make_graph(nodes, "GPT2_embedlayer_graph", inputs, outputs, initializers)
|
||||
opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
|
||||
return helper.make_model(graph, opset_imports=(opsetid,))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = create_gpt2_attention()
|
||||
onnx.save(model, "gpt2_attention.onnx")
|
||||
|
||||
model = create_gpt2_attention(switch_add_inputs=True)
|
||||
onnx.save(model, "gpt2_attention_add.onnx")
|
||||
|
||||
model = create_gpt2_embedlayer()
|
||||
onnx.save(model, "gpt2_embedlayer.onnx")
|
||||
|
||||
model = create_gpt2_fused_embedlayer()
|
||||
onnx.save(model, "gpt2_embedlayer_exp.onnx")
|
||||
|
||||
model = create_gpt2_embedlayer(one_attention_node=True)
|
||||
onnx.save(model, "gpt2_embedlayer_one_attn.onnx")
|
||||
|
||||
model = create_gpt2_fused_embedlayer(one_attention_node=True)
|
||||
onnx.save(model, "gpt2_embedlayer_one_attn_exp.onnx")
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,79 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from gpt2_model_generator import create_gpt2_embedlayer
|
||||
from parity_utilities import find_transformers_source
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
if find_transformers_source():
|
||||
from onnx_model import OnnxModel
|
||||
from optimizer import optimize_model
|
||||
else:
|
||||
from onnxruntime.transformers.onnx_model import OnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_model
|
||||
|
||||
|
||||
class TestFusion(unittest.TestCase):
|
||||
def verify_fusion(self, optimized_model, expected_model_filename):
|
||||
optimized_model.topological_sort()
|
||||
|
||||
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
|
||||
expected_model = OnnxModel(onnx.load(expected_model_path))
|
||||
expected_model.topological_sort()
|
||||
|
||||
self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
|
||||
|
||||
def verify_parity(self, optimized_model_path, expected_model):
|
||||
expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model)
|
||||
sess_optimized = InferenceSession(optimized_model_path, providers=["CPUExecutionProvider"])
|
||||
sess_expected = InferenceSession(expected_model_path, providers=["CPUExecutionProvider"])
|
||||
inputs = np.random.randint(low=0, high=6, size=(4, 8), dtype=np.int32) + 1
|
||||
|
||||
outputs_optimized = sess_optimized.run(None, {"ids": inputs})
|
||||
outputs_expected = sess_expected.run(None, {"ids": inputs})
|
||||
self.assertTrue(np.allclose(outputs_optimized[0], outputs_expected[0]))
|
||||
|
||||
def test_embedlayer_fusion(self):
|
||||
model = create_gpt2_embedlayer(one_attention_node=False)
|
||||
path = "."
|
||||
original_model_path = os.path.join(path, "gpt2_embedlayer.onnx")
|
||||
optimized_model_path = os.path.join(path, "gpt2_embedlayer_opt.onnx")
|
||||
expected_model_filename = "gpt2_embedlayer_exp.onnx"
|
||||
|
||||
onnx.save(model, original_model_path)
|
||||
optimized_model = optimize_model(original_model_path, model_type="gpt2")
|
||||
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
|
||||
|
||||
self.verify_fusion(optimized_model, expected_model_filename)
|
||||
self.verify_parity(optimized_model_path, expected_model_filename)
|
||||
os.remove(original_model_path)
|
||||
os.remove(optimized_model_path)
|
||||
|
||||
def test_embedlayer_fusion_one_attn_node(self):
|
||||
model = create_gpt2_embedlayer(one_attention_node=True)
|
||||
path = "."
|
||||
original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn.onnx")
|
||||
optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_opt.onnx")
|
||||
expected_model_filename = "gpt2_embedlayer_one_attn_exp.onnx"
|
||||
|
||||
onnx.save(model, original_model_path)
|
||||
optimized_model = optimize_model(original_model_path, model_type="gpt2")
|
||||
optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True)
|
||||
|
||||
self.verify_fusion(optimized_model, expected_model_filename)
|
||||
self.verify_parity(optimized_model_path, expected_model_filename)
|
||||
os.remove(original_model_path)
|
||||
os.remove(optimized_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue