diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 63951d6a89..31a88be543 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -111,6 +111,8 @@ static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Nod The Expand and Gather on the bottom will not be added to subgraph_node_indices. It is because they are matched as part of other subgraph. + + Two Shape nodes may merge into one. */ static bool MatchInputToConcatSubgraph( @@ -134,8 +136,14 @@ static bool MatchInputToConcatSubgraph( DEBUG_LOG("Failed to find path 1 of position shape."); return false; } + const size_t shape_index = edges.size() - 1; for (size_t i = 0; i < edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), 1)) { + // Shape may have multiple outputs due to shape integration + // So check it later + if (i == shape_index) { + continue; + } DEBUG_LOG("Output edge count not expected for nodes in path 1 of position shape."); return false; } @@ -161,9 +169,10 @@ static bool MatchInputToConcatSubgraph( return false; } + // Shape may have multiple outputs due to shape integration + // Check it later if (!optimizer_utils::CheckOutputEdges(graph, edges[0]->GetNode(), 1) || - !optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2) || - !optimizer_utils::CheckOutputEdges(graph, edges[2]->GetNode(), 1)) { + !optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2)) { DEBUG_LOG("Output edge count not expected for nodes in path 2 of position shape."); return false; } @@ -189,6 +198,19 @@ static bool MatchInputToConcatSubgraph( return false; } + // Check if shape have more than one output, it may due to shape integration + // We check if they share the same node + if (!optimizer_utils::CheckOutputEdges(graph, shape_node_0, 1) || + !optimizer_utils::CheckOutputEdges(graph, shape_node_1, 1)) { + if (shape_node_0.Index() == shape_node_1.Index() && + (shape_node_0.GetOutputEdgesCount() == 2 || + shape_node_0.GetOutputEdgesCount() == 4)) { + DEBUG_LOG("two paths share the same shape"); + } else { + return false; + } + } + AddNodes(subgraph_node_indices, edges); return true; } @@ -210,6 +232,7 @@ static bool MatchInputToConcatSubgraph( * Note that position gather node is the node in the bottom of above sub-graph. * Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. * Path in ** is an alternative path to check. + * Two shape node may merge into one */ static bool MatchPositionEmbeddingSubgraphsFromGather( Graph& graph, @@ -268,12 +291,19 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( return false; } const size_t gather_index = pg_edges.size() - 2; + const size_t shape_index = pg_edges.size() - 1; // All nodes in Path 1 must have only 1 output edge, except the gather node allowed 1 or 2 output edges + // And shape node allowed multiple output edges due to shape integration for (size_t i = 0; i < pg_edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) { if (i == gather_index && optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2)) { continue; } + if (i == shape_index && + (optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2) || + optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 4))) { + continue; + } DEBUG_LOG("Output edge count not expected for nodes in path1."); return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a1ec44fd88..101588b7ae 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2630,19 +2630,25 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); } -//DistilBert -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx"; +static void TestEmbedLayerNormFusionDistilBert(const std::basic_string& model_uri, + std::map& op_to_count, + logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret1.IsOK()); - std::map op_to_count = CountOpsInGraph(graph); + op_to_count = CountOpsInGraph(graph); +} + +//DistilBert +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get()); EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); EXPECT_EQ(op_to_count["Attention"], 1); EXPECT_EQ(op_to_count["Cast"], 2); @@ -2652,6 +2658,30 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { EXPECT_EQ(op_to_count["ReduceSum"], 1); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 2); + EXPECT_EQ(op_to_count["Unsqueeze"], 2); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx new file mode 100644 index 0000000000..a133cc53cf Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx new file mode 100644 index 0000000000..12cb4f977e Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index 9fcd169f3a..1118023ef3 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -278,12 +278,33 @@ def GenerateModel6(model_name): model = helper.make_model(graph) onnx.save(model, model_name) -def GenerateModel7(model_name): - batch_size = 2 - hidden_size = 4 - attention_heads = 2 - sequence_length = 3 +def GenerateInitializers2(hidden_size): + qkv_weights = [ + 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, + 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0 + ] + initializers = [ # initializers + helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), + helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), + helper.make_tensor('start', TensorProto.INT64, [], [0]), + helper.make_tensor('delta', TensorProto.INT64, [], [1]), + helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + ] + + return initializers + +def GenerateNodes2(attention_heads): nodes = [ helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0), @@ -313,28 +334,17 @@ def GenerateModel7(model_name): helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3") ] - qkv_weights = [ - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, - 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0 - ] + return nodes - initializers = [ # initializers - helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('start', TensorProto.INT64, [], [0]), - helper.make_tensor('delta', TensorProto.INT64, [], [1]), - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size], - [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - ] +def GenerateModel7(model_name): + batch_size = 2 + hidden_size = 4 + attention_heads = 2 + sequence_length = 3 + + nodes = GenerateNodes2(attention_heads) + + initializers = GenerateInitializers2(hidden_size) graph = helper.make_graph( nodes, @@ -351,9 +361,87 @@ def GenerateModel7(model_name): model = helper.make_model(graph) onnx.save(model, model_name) +def GenerateModel8(model_name): + batch_size = -1 + hidden_size = 4 + attention_heads = 2 + sequence_length = -1 + + nodes = GenerateNodes2(attention_heads) + + del nodes[5:7] + del nodes[1:3] + new_nodes = [ + helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"), + helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"), + helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand") + ] + nodes = nodes + new_nodes + + initializers = GenerateInitializers2(hidden_size) + + graph = helper.make_graph( + nodes, + "EmbedLayerNorm_format8", #name + [ # inputs + helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]), + ], + [ # outputs + helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) + +def GenerateModel9(model_name): + batch_size = -1 + hidden_size = 4 + attention_heads = 2 + sequence_length = -1 + + nodes = GenerateNodes2(attention_heads) + + del nodes[10] + del nodes[5:7] + del nodes[1:3] + new_nodes = [ + helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"), + helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"), + helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"), + helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"), + helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"), + helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), + helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape", + value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])), + helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6), + ] + nodes = nodes + new_nodes + + initializers = GenerateInitializers2(hidden_size) + + graph = helper.make_graph( + nodes, + "EmbedLayerNorm_format9", #name + [ # inputs + helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), + ], + [ # outputs + helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) + GenerateModel3('embed_layer_norm_format3.onnx', True) GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) GenerateModel5('embed_layer_norm_format5.onnx') GenerateModel6('embed_layer_norm_format6.onnx') GenerateModel7('embed_layer_norm_format7.onnx') #distilbert +GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask +GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')