From c239ff0750fb2e626e3d5f3a154c99551d543cb9 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 8 Sep 2020 14:17:29 -0700 Subject: [PATCH] Modify embedlayernorm fusion due to shape node merging (#4967) * modify embedlayernorm fusion due to shape integration * update * update comments * review comments * review comments * fix test --- .../core/optimizer/embed_layer_norm_fusion.cc | 34 ++++- .../test/optimizer/graph_transform_test.cc | 42 +++++- .../fusion/embed_layer_norm_format8.onnx | Bin 0 -> 1919 bytes .../fusion/embed_layer_norm_format9.onnx | Bin 0 -> 2308 bytes .../transform/fusion/embed_layer_norm_gen.py | 140 ++++++++++++++---- 5 files changed, 182 insertions(+), 34 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 63951d6a89..31a88be543 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -111,6 +111,8 @@ static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Nod The Expand and Gather on the bottom will not be added to subgraph_node_indices. It is because they are matched as part of other subgraph. + + Two Shape nodes may merge into one. */ static bool MatchInputToConcatSubgraph( @@ -134,8 +136,14 @@ static bool MatchInputToConcatSubgraph( DEBUG_LOG("Failed to find path 1 of position shape."); return false; } + const size_t shape_index = edges.size() - 1; for (size_t i = 0; i < edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), 1)) { + // Shape may have multiple outputs due to shape integration + // So check it later + if (i == shape_index) { + continue; + } DEBUG_LOG("Output edge count not expected for nodes in path 1 of position shape."); return false; } @@ -161,9 +169,10 @@ static bool MatchInputToConcatSubgraph( return false; } + // Shape may have multiple outputs due to shape integration + // Check it later if (!optimizer_utils::CheckOutputEdges(graph, edges[0]->GetNode(), 1) || - !optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2) || - !optimizer_utils::CheckOutputEdges(graph, edges[2]->GetNode(), 1)) { + !optimizer_utils::CheckOutputEdges(graph, edges[1]->GetNode(), 2)) { DEBUG_LOG("Output edge count not expected for nodes in path 2 of position shape."); return false; } @@ -189,6 +198,19 @@ static bool MatchInputToConcatSubgraph( return false; } + // Check if shape have more than one output, it may due to shape integration + // We check if they share the same node + if (!optimizer_utils::CheckOutputEdges(graph, shape_node_0, 1) || + !optimizer_utils::CheckOutputEdges(graph, shape_node_1, 1)) { + if (shape_node_0.Index() == shape_node_1.Index() && + (shape_node_0.GetOutputEdgesCount() == 2 || + shape_node_0.GetOutputEdgesCount() == 4)) { + DEBUG_LOG("two paths share the same shape"); + } else { + return false; + } + } + AddNodes(subgraph_node_indices, edges); return true; } @@ -210,6 +232,7 @@ static bool MatchInputToConcatSubgraph( * Note that position gather node is the node in the bottom of above sub-graph. * Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. * Path in ** is an alternative path to check. + * Two shape node may merge into one */ static bool MatchPositionEmbeddingSubgraphsFromGather( Graph& graph, @@ -268,12 +291,19 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( return false; } const size_t gather_index = pg_edges.size() - 2; + const size_t shape_index = pg_edges.size() - 1; // All nodes in Path 1 must have only 1 output edge, except the gather node allowed 1 or 2 output edges + // And shape node allowed multiple output edges due to shape integration for (size_t i = 0; i < pg_edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) { if (i == gather_index && optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2)) { continue; } + if (i == shape_index && + (optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 2) || + optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 4))) { + continue; + } DEBUG_LOG("Output edge count not expected for nodes in path1."); return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a1ec44fd88..101588b7ae 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2630,19 +2630,25 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); } -//DistilBert -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx"; +static void TestEmbedLayerNormFusionDistilBert(const std::basic_string& model_uri, + std::map& op_to_count, + logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret1.IsOK()); - std::map op_to_count = CountOpsInGraph(graph); + op_to_count = CountOpsInGraph(graph); +} + +//DistilBert +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get()); EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); EXPECT_EQ(op_to_count["Attention"], 1); EXPECT_EQ(op_to_count["Cast"], 2); @@ -2652,6 +2658,30 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { EXPECT_EQ(op_to_count["ReduceSum"], 1); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 2); + EXPECT_EQ(op_to_count["Unsqueeze"], 2); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a133cc53cf0547a18a3af8effcb0d2065257b4ab GIT binary patch literal 1919 zcmd^9OHbQC5ROCQ#{@#UYAcijLKYJ0Rz=~VQX!QFSEyA5JwU5Uy;w`U5R2Fj@w!k> z9Qqq7(Gx$Qf1@|}JH669I7!R0|qta))IbD^=OX~1OpNt zQVBJ6b+)U!0)sYWl&GGB@o_@w+0YoqvTc1F>D0?KZu6H7cnelPiib%6J?)dw6OW_1 zfHO(JSiUQVV&YJ^1DVA0mSjmP;$%;3dLDe!AjRQ9k9_4-3!_2co>2d=%V0VK?fWD~ zQxcFSqPkWp+GtyNzvs@R=bsVgMAzZTXmH4WAPR0{(V zAE7Dn?{-McWVlu;+LhNF#1@oUG^2~&eSo|%s!L-(^ysN{hOum0yVOfMbT0|gHt={1 zJPqa%rG3v2h7-|;e40=@KKe2&NereCL@}W77%~V`Ay4pwZEZ3}Ls^FgSYZ;lT}r%o zrpyChpXx-xV&HcMQ5+pGScQ?VNQebVZ}=SI2}n^?TUTdW-yv)#>A_RTXk{a1T(Z^6 zX{(o2s~4q#gbV3O`<2vwMcL11X=TM-(x=j-Dvx$`Par(-u}V02euu{H63TL^X$*-q znqA)S0$nn7=;_V%*pZylR%sN<*{fN~QBxrBbQQzx;PD zemU96>gDy-_QsEkiw%C>aEh$N-+A`tB+vi6T)g*xlbtMAaa&fURjlBBvXgQgf8v!DlXYv7 z{JeENT4tS(brwFl!p92{m)HOK-HC@Od$}MDMlVZ^?xqd5`A7HZ?Sx NyD%^C33JMg$v-^VMVSBq literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx new file mode 100644 index 0000000000000000000000000000000000000000..12cb4f977e0832ef4e813c3bca261b4b2f48f74e GIT binary patch literal 2308 zcmd^ANpI6Y6pquxv0u8R7F1OZsIrg{gFs1(Kvbf%MuffcgizMH-+a0JOX==q6nt3h(~qKRaWahj0~YS?aahn;+aXPz zIb@)Stgc;g<}&D1Vmq;TmYS*bj^gJGNSMY*lsa1_~g1B0ZE}E9BRd{AGPR z>~xO~7>d;K=+<9JjtIQGk_tiCW2xVD^p$8g^8FZ%Al!zDsn@JH+B~;SHQTQF0B3xP zyUxg+8q?few{Z7|AXTYyt<^$nYqZyBRoheLrQFl*)IBXspbO1}a#@;KooeE-Rw{I? z(aI|HlDJ3J;F;>|n)i`6j#vAji8>dBsDQ6nDDL=Wdl%2K-34I`C131d(}@f9544{# zvSb-8!>v^+$M-z7X=v!wsDTH+ci@aKY}+pO*R9