mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Modify embedlayernorm fusion due to shape node merging (#4967)
* modify embedlayernorm fusion due to shape integration * update * update comments * review comments * review comments * fix test
This commit is contained in:
parent
38453acae3
commit
c239ff0750
5 changed files with 182 additions and 34 deletions
|
|
@ -111,6 +111,8 @@ static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Nod
|
|||
|
||||
The Expand and Gather on the bottom will not be added to subgraph_node_indices.
|
||||
This is because they are matched as part of another subgraph.
|
||||
|
||||
Two Shape nodes may merge into one.
|
||||
*/
|
||||
|
||||
static bool MatchInputToConcatSubgraph(
|
||||
|
|
@ -134,8 +136,14 @@ static bool MatchInputToConcatSubgraph(
|
|||
DEBUG_LOG("Failed to find path 1 of position shape.");
|
||||
return false;
|
||||
}
|
||||
const size_t shape_index = edges.size() - 1;
|
||||
for (size_t i = 0; i < edges.size(); i++) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), 1)) {
|
||||
// Shape may have multiple outputs due to shape integration
|
||||
// So check it later
|
||||
if (i == shape_index) {
|
||||
continue;
|
||||
}
|
||||
DEBUG_LOG("Output edge count not expected for nodes in path 1 of position shape.");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -161,9 +169,10 @@ static bool MatchInputToConcatSubgraph(
|
|||
return false;
|
||||
}
|
||||
|
||||
// Shape may have multiple outputs due to shape integration
|
||||
// Check it later
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, edges[0]->GetNode(), 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, edges[2]->GetNode(), 1)) {
|
||||
!optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2)) {
|
||||
DEBUG_LOG("Output edge count not expected for nodes in path 2 of position shape.");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -189,6 +198,19 @@ static bool MatchInputToConcatSubgraph(
|
|||
return false;
|
||||
}
|
||||
|
||||
// Check whether a Shape node has more than one output edge; this may be due to shape integration.
|
||||
// We check if they share the same node
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, shape_node_0, 1) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, shape_node_1, 1)) {
|
||||
if (shape_node_0.Index() == shape_node_1.Index() &&
|
||||
(shape_node_0.GetOutputEdgesCount() == 2 ||
|
||||
shape_node_0.GetOutputEdgesCount() == 4)) {
|
||||
DEBUG_LOG("two paths share the same shape");
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
AddNodes(subgraph_node_indices, edges);
|
||||
return true;
|
||||
}
|
||||
|
|
@ -210,6 +232,7 @@ static bool MatchInputToConcatSubgraph(
|
|||
* Note that position gather node is the node in the bottom of above sub-graph.
|
||||
* Paths in ^^ are alternative paths to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
|
||||
* Path in ** is an alternative path to check.
|
||||
* Two Shape nodes may merge into one.
|
||||
*/
|
||||
static bool MatchPositionEmbeddingSubgraphsFromGather(
|
||||
Graph& graph,
|
||||
|
|
@ -268,12 +291,19 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
|
|||
return false;
|
||||
}
|
||||
const size_t gather_index = pg_edges.size() - 2;
|
||||
const size_t shape_index = pg_edges.size() - 1;
|
||||
// All nodes in Path 1 must have only 1 output edge, except the Gather node, which is allowed 1 or 2 output edges.
|
||||
// The Shape node is also allowed multiple output edges (2 or 4) due to shape integration.
|
||||
for (size_t i = 0; i < pg_edges.size(); i++) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) {
|
||||
if (i == gather_index && optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2)) {
|
||||
continue;
|
||||
}
|
||||
if (i == shape_index &&
|
||||
(optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2) ||
|
||||
optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 4))) {
|
||||
continue;
|
||||
}
|
||||
DEBUG_LOG("Output edge count not expected for nodes in path1.");
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2630,19 +2630,25 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
|
|||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
}
|
||||
|
||||
//DistilBert
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx";
|
||||
static void TestEmbedLayerNormFusionDistilBert(const std::basic_string<ORTCHAR_T>& model_uri,
|
||||
std::map<std::string, int>& op_to_count,
|
||||
logging::Logger* logger) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
|
||||
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
|
||||
ASSERT_TRUE(ret1.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
}
|
||||
|
||||
//DistilBert
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get());
|
||||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
EXPECT_EQ(op_to_count["Cast"], 2);
|
||||
|
|
@ -2652,6 +2658,30 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
|
|||
EXPECT_EQ(op_to_count["ReduceSum"], 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get());
|
||||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
EXPECT_EQ(op_to_count["Cast"], 2);
|
||||
EXPECT_EQ(op_to_count["Shape"], 0);
|
||||
EXPECT_EQ(op_to_count["Gather"], 0);
|
||||
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
|
||||
EXPECT_EQ(op_to_count["ReduceSum"], 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get());
|
||||
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
|
||||
EXPECT_EQ(op_to_count["Attention"], 1);
|
||||
EXPECT_EQ(op_to_count["Cast"], 2);
|
||||
EXPECT_EQ(op_to_count["Shape"], 0);
|
||||
EXPECT_EQ(op_to_count["Gather"], 2);
|
||||
EXPECT_EQ(op_to_count["Unsqueeze"], 2);
|
||||
EXPECT_EQ(op_to_count["ReduceSum"], 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -278,12 +278,33 @@ def GenerateModel6(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
def GenerateModel7(model_name):
|
||||
batch_size = 2
|
||||
hidden_size = 4
|
||||
attention_heads = 2
|
||||
sequence_length = 3
|
||||
def GenerateInitializers2(hidden_size):
|
||||
qkv_weights = [
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
|
||||
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
|
||||
]
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('start', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('delta', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size],
|
||||
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
]
|
||||
|
||||
return initializers
|
||||
|
||||
def GenerateNodes2(attention_heads):
|
||||
nodes = [
|
||||
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
|
||||
|
||||
|
|
@ -313,28 +334,17 @@ def GenerateModel7(model_name):
|
|||
helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3")
|
||||
]
|
||||
|
||||
qkv_weights = [
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
|
||||
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
|
||||
]
|
||||
return nodes
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('start', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('delta', TensorProto.INT64, [], [1]),
|
||||
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size],
|
||||
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
]
|
||||
def GenerateModel7(model_name):
|
||||
batch_size = 2
|
||||
hidden_size = 4
|
||||
attention_heads = 2
|
||||
sequence_length = 3
|
||||
|
||||
nodes = GenerateNodes2(attention_heads)
|
||||
|
||||
initializers = GenerateInitializers2(hidden_size)
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
|
|
@ -351,9 +361,87 @@ def GenerateModel7(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
def GenerateModel8(model_name):
|
||||
batch_size = -1
|
||||
hidden_size = 4
|
||||
attention_heads = 2
|
||||
sequence_length = -1
|
||||
|
||||
nodes = GenerateNodes2(attention_heads)
|
||||
|
||||
del nodes[5:7]
|
||||
del nodes[1:3]
|
||||
new_nodes = [
|
||||
helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"),
|
||||
helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand")
|
||||
]
|
||||
nodes = nodes + new_nodes
|
||||
|
||||
initializers = GenerateInitializers2(hidden_size)
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"EmbedLayerNorm_format8", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
|
||||
helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
|
||||
],
|
||||
initializers)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
def GenerateModel9(model_name):
|
||||
batch_size = -1
|
||||
hidden_size = 4
|
||||
attention_heads = 2
|
||||
sequence_length = -1
|
||||
|
||||
nodes = GenerateNodes2(attention_heads)
|
||||
|
||||
del nodes[10]
|
||||
del nodes[5:7]
|
||||
del nodes[1:3]
|
||||
new_nodes = [
|
||||
helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"),
|
||||
helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"),
|
||||
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape",
|
||||
value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])),
|
||||
helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
]
|
||||
nodes = nodes + new_nodes
|
||||
|
||||
initializers = GenerateInitializers2(hidden_size)
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"EmbedLayerNorm_format9", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
|
||||
],
|
||||
initializers)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
GenerateModel3('embed_layer_norm_format3.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6.onnx')
|
||||
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
|
||||
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
|
||||
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
|
||||
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
|
||||
|
|
|
|||
Loading…
Reference in a new issue