From 200f4b4ea616ef4a35aeced6795e997ecfa2349d Mon Sep 17 00:00:00 2001 From: liuziyue Date: Sat, 7 Dec 2019 23:14:26 -0800 Subject: [PATCH] EmbedLayerNormalization Fusion Improvement (#2553) Embedding layer norm fusion improvements - add more checks --- .../core/optimizer/embed_layer_norm_fusion.cc | 205 +++++++++++++----- .../test/optimizer/graph_transform_test.cc | 5 + .../fusion/embed_layer_norm_format1.onnx | Bin 940 -> 985 bytes .../fusion/embed_layer_norm_format2.onnx | Bin 1060 -> 1787 bytes 4 files changed, 154 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 34896d7efc..970de353f5 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -3,6 +3,8 @@ #include "core/optimizer/initializer.h" #include "core/optimizer/embed_layer_norm_fusion.h" #include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" +#include "core/framework/tensorprotoutils.h" #include "float.h" #define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x @@ -13,6 +15,10 @@ namespace onnxruntime { // Add a Cast to convert Input from int64 to int32. static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_type) { + auto data_type = input->TypeAsProto()->tensor_type().elem_type(); + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) { + return input; + } const TensorShapeProto* input_shape = input->Shape(); TypeProto input_int32; input_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); @@ -41,26 +47,22 @@ static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_ return &cast32; } -static NodeArg* CheckInput(Graph& graph, NodeArg* input, ProviderType provider_type, const logging::Logger& logger) { +static bool CheckInput(NodeArg* input, const logging::Logger& logger) { // Validate input shape (batch_size, sequence_length) and data type. // Note that batch_size and sequence_length could be symbolic. const TensorShapeProto* input_shape = input->Shape(); if (input_shape == nullptr || input_shape->dim_size() != 2 || input->Type() == nullptr) { - DEBUG_LOG("Mask shape is unknown or not 2D, or data type unknown"); - return nullptr; + DEBUG_LOG("Input shape is unknown or not 2D, or data type unknown"); + return false; } auto data_type = input->TypeAsProto()->tensor_type().elem_type(); if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 && data_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) { DEBUG_LOG("Input data type is not int32 or int64"); - return nullptr; + return false; } - - if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - return CastToInt32(graph, input, provider_type); - } - return input; + return true; } /** @@ -124,8 +126,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } // The first input of segment_gather_node must be 2d. - auto sg_shape = segment_gather_node.MutableInputDefs()[0]->Shape(); - if (sg_shape != nullptr && sg_shape->dim_size() != 2) { + NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0]; + auto sg_shape = segment_embedding->Shape(); + if (sg_shape == nullptr || sg_shape->dim_size() != 2 || + !utils::HasDimValue(sg_shape->dim()[1]) || + sg_shape->dim()[1].dim_value() <= 0) { continue; } @@ -142,8 +147,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } // The first input of word_gather_node must be 2d. - auto wg_shape = word_gather_node.MutableInputDefs()[0]->Shape(); - if (wg_shape != nullptr && wg_shape->dim_size() != 2) { + NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0]; + auto wg_shape = word_embedding->Shape(); + if (wg_shape == nullptr || wg_shape->dim_size() != 2 || + !utils::HasDimValue(wg_shape->dim()[1]) || + wg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) { continue; } @@ -160,51 +168,157 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } // The first input of position_gather_node must be 2d. - auto pg_shape = position_gather_node.MutableInputDefs()[0]->Shape(); - if (pg_shape != nullptr && pg_shape->dim_size() != 2) { + NodeArg* position_embedding = position_gather_node.MutableInputDefs()[0]; + auto pg_shape = position_embedding->Shape(); + if (pg_shape == nullptr || pg_shape->dim_size() != 2 || + !utils::HasDimValue(pg_shape->dim()[1]) || + pg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) { continue; } - // Match Shape --> Expand path if needed. - std::vector position_embedding_path_symbolic{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 1, "Shape", {1}, kOnnxDomain}}; + // Check the second input of position gather. If it's not initializer, check for two paths. Node* p_expand_node = nullptr; Node* p_shape_node = nullptr; - if (graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) { - if (edges[0]->GetNode().GetOutputEdgesCount() == 1 && edges[1]->GetNode().GetOutputEdgesCount() == 1) { - p_expand_node = graph.GetNode(edges[0]->GetNode().Index()); - p_shape_node = graph.GetNode(edges[1]->GetNode().Index()); + std::vector pg_edges; + bool isValidEmbedSubNode = true; + if (graph_utils::IsConstantInitializer(graph, position_gather_node.MutableInputDefs()[1]->Name())) { + // Check if the second input of position gather is a tensor with values evenly spaced by 1 starting from 0. + std::vector data; + auto expected_shape = word_gather_node.MutableInputDefs()[1]->Shape(); + if (!optimizer_utils::AppendTensorFromInitializer(graph, *(position_gather_node.MutableInputDefs()[1]), data) + || !utils::HasDimValue(expected_shape->dim()[0]) + || !utils::HasDimValue(expected_shape->dim()[1]) + || static_cast(data.size()) != expected_shape->dim()[0].dim_value() * expected_shape->dim()[1].dim_value()) { + continue; } + int64_t expected_value = 0; + for (size_t i = 0; i < data.size(); i++) { + if (data[i] != expected_value) { + isValidEmbedSubNode = false; + break; + } + expected_value++; + if (expected_value >= static_cast(expected_shape->dim()[1].dim_value())) { + expected_value = 0; + } + } + } else { + // Match two paths. + // Match Shape --> Expand path if needed. + std::vector position_parent_nodes; + std::vector position_embedding_path_symbolic{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 1, "Shape", {1}, kOnnxDomain}}; + if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) { + continue; + } + if (edges[0]->GetNode().GetOutputEdgesCount() != 1 && edges[1]->GetNode().GetOutputEdgesCount() != 1) { + continue; + } + p_expand_node = graph.GetNode(edges[0]->GetNode().Index()); + p_shape_node = graph.GetNode(edges[1]->GetNode().Index()); + // Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand + Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index()); + Node& shape_node_1 = *graph.GetNode(edges[1]->GetNode().Index()); + std::vector pg_parent_path{ + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Cast", {9}, kOnnxDomain}, + {0, 0, "Squeeze", {1}, kOnnxDomain}, + {0, 0, "Transpose", {1}, kOnnxDomain}, + {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}, + }; + if (!graph_utils::FindPath(expand_node, true, pg_parent_path, pg_edges, logger)) { + continue; + } + for (size_t i = 0; i < pg_edges.size(); i++) { + if (pg_edges[i]->GetNode().GetOutputEdgesCount() != 1) { + isValidEmbedSubNode = false; + break; + } + } + // Check if the second input of the Gather node in the path has a constant input of 1 + Node& gather_node = *graph.GetNode(pg_edges[pg_edges.size() - 2]->GetNode().Index()); + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(gather_node.InputDefs()[1]), int64_t(1), true)) { + DEBUG_LOG("Second input of Gather should be a constant with value 1. "); + + continue; + } + // Check if the two paths of position gather lead to the same input. + Node& shape_node_2 = *graph.GetNode(pg_edges[pg_edges.size() - 1]->GetNode().Index()); + if (shape_node_1.MutableInputDefs()[0] != shape_node_2.MutableInputDefs()[0]) { + continue; + } + // Check if the parent of "shape" is the parent of "word gather" + if (shape_node_1.MutableInputDefs()[0] != word_gather_node.MutableInputDefs()[1]) { + continue; + } + + } + if (!isValidEmbedSubNode) { + continue; } // Get input "input_ids" from node. - NodeArg* input_ids = CheckInput(graph, word_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger); - if (input_ids == nullptr) { + NodeArg* input_ids = word_gather_node.MutableInputDefs()[1]; + if (!CheckInput(input_ids, logger)) { DEBUG_LOG("Input id is not valid. "); continue; } // Get input "segment_ids" from node. - NodeArg* segment_ids = CheckInput(graph, segment_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger); - if (segment_ids == nullptr) { + NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1]; + if (!CheckInput(segment_ids, logger)) { DEBUG_LOG("Segment id is not valid. "); continue; } // Get input "mask" from "ReduceSum" node. - NodeArg* mask = CheckInput(graph, reduce_sum_node.MutableInputDefs()[0], layer_norm_node.GetExecutionProviderType(), logger); - if (mask == nullptr) { + NodeArg* mask = reduce_sum_node.MutableInputDefs()[0]; + if (!CheckInput(mask, logger)) { DEBUG_LOG("Mask is not valid. "); continue; } + if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != + utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) { + DEBUG_LOG("Input_ids and segment id should have the same shape. "); + continue; + } + if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != + utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) { + DEBUG_LOG("Input_ids and mask should have the same shape. "); + continue; + } + + NodeArg* gamma = layer_norm_node.MutableInputDefs()[1]; + NodeArg* beta = layer_norm_node.MutableInputDefs()[2]; + if (gamma->Shape() == nullptr + || gamma->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) { + DEBUG_LOG("Gamma should be of shape (hidden_size). "); + continue; + } + + if (beta->Shape() == nullptr + || beta->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) { + DEBUG_LOG("Beta should be of shape (hidden_size). "); + continue; + } + + // Cast input_ids, segment_ids, and mask to int32 if needed. + input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType()); + segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType()); + mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType()); + const std::vector embed_layer_norm_input_defs{ input_ids, segment_ids, - word_gather_node.MutableInputDefs()[0], - position_gather_node.MutableInputDefs()[0], - segment_gather_node.MutableInputDefs()[0], + word_embedding, + position_embedding, + segment_embedding, layer_norm_node.MutableInputDefs()[1], layer_norm_node.MutableInputDefs()[2], mask}; @@ -222,31 +336,10 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l // move output definitions and output edges to embed_layer_norm_node. // remove all the other nodes. std::vector nodes_to_remove; + for (size_t i = 0; i < pg_edges.size(); i++) { + nodes_to_remove.push_back(pg_edges[i]->GetNode().Index()); + } if (p_shape_node != nullptr && p_expand_node != nullptr) { - // Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand - if (p_expand_node != nullptr) { - Node& expand_node = *graph.GetNode(p_expand_node->Index()); - std::vector expand_parent_path{ - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Cast", {9}, kOnnxDomain}, - {0, 0, "Squeeze", {1}, kOnnxDomain}, - {0, 0, "Transpose", {1}, kOnnxDomain}, - {0, 0, "NonZero", {9}, kOnnxDomain}, - {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}, - }; - if (graph_utils::FindPath(expand_node, true, expand_parent_path, edges, logger)) { - for (size_t i = 0; i < edges.size(); i++) { - if (edges[i]->GetNode().GetOutputEdgesCount() != 1) { - nodes_to_remove.clear(); - break; - } - nodes_to_remove.push_back(edges[i]->GetNode().Index()); - } - } - } nodes_to_remove.push_back(p_shape_node->Index()); nodes_to_remove.push_back(p_expand_node->Index()); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a186c26808..bd7665fded 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1317,6 +1317,11 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) { ASSERT_TRUE(op_to_count["Shape"] == 0); ASSERT_TRUE(op_to_count["Expand"] == 0); ASSERT_TRUE(op_to_count["Gather"] == 0); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); + ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0); + ASSERT_TRUE(op_to_count["NonZero"] == 0); + ASSERT_TRUE(op_to_count["Transpose"] == 0); + ASSERT_TRUE(op_to_count["Squeeze"] == 0); ASSERT_TRUE(op_to_count["Add"] == 0); ASSERT_TRUE(op_to_count["ReduceSum"] == 0); ASSERT_TRUE(op_to_count["Attention"] == 1); diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format1.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format1.onnx index 413c52bdb280d5976c5b3624353ce017a2129e4d..cd535abf3a95fcb92645aa734920c595dd2e0034 100644 GIT binary patch delta 335 zcmZ3(ev^HI1k-NziT0us^##Rrx!B56Gt)ClxVY03OEOZ6;xqFyOD6h!Vsx5p&!{4& z#m&W%Sdm)nz_@^sU5lTKBRe&Y~QueY|y|#DvdD!1wzSqvFGTB~hhqS#x zZGgS}rscL98!p(%v^}=9)i$=Tn19#iuC|c9VWztML?u=GwOOa_e(k(#_kDkvy-r7j zUGo}uTR!<}`|YLX_LDC%$*I)6SZBLS`h?vTYh(N5DNgpUc(Ux(H_6#IA31I3BM@jO z5yWmkS(rK8NrsCpzqF*Fv_wdVOMrt>h=+@bg9(V4ftWE%o=dPOHKjB;HNLnsHy)~x Ri;IJUO^5}g(1}Ta3jkImX*d7? delta 285 zcmcb~zJ`5*1Y_SsNp&7&F1GU2%=C;BF2;z7VV@XPCVK*LH{)(GF%CunHZ4XbC*IVG zg2cR(c(A(3iA;t$XT0O=_89cp>Q9+%x4V6}9rJxI`-~%|_V1=>+Uwoow`buIwYL#< zx3_$mY;XVH$UY(=#QwnA$9AvnuG?)0m9kqD5@lEDm1A#pEzN$5fr4FdoRa;@)2r>~ zCo9{tPH3<*S}J7!^$DYm(z0iElm9Wvsq~(ex9@covG*-3uwQ;G)&5T7eOu#*Q2PcZ zR{Mijr`Uz$x7bd0U=DYZ;9|=!Eh#815#rAX^+MuR}WFNv0+^PnxAnbMgekQf&( zj7E%cuaXd>8xz_1Te#7c8&@ujQR2?HaOKLx#QD9UGn7R$z31F_|M#-Bs6{DhU0O+y z)mfz^30h8()xmEKun&UQNNz5x0dom8&IwwUlzP)jSu2*cc<_`dL(CxS)vQn?l&YOH zBxp0329xRYQ2pS>q0X|N&dE8Ay6UHosH;J+gP;tGlx20HOvsAkJA9rZTpa|2C88#M zYfK1^!!}(M&FG10f~+aCuK1gVZG3Fj%A^-|P(_WhBCn8=LPHIWD)N*rLl<}`6frAY zRMf3l&lxp05Ey`MG-~vu7MiqDbrNHY#Ja(whw3tds?mS~!@9tQ!LKQ?&>G{{Jyr6k z*{Pt@qn%1P0Na(p)p^uFmWqNbQs==FrU*UGAh{K2rybX{(|xqlblj*+HBp@)ij<7k z42fdpv|gr$k2;aq)7k@-52;nEW?H2#EURbD2NZ)^H6s)V1ZHxAioyliw%?9Mgk@41 zHP)w)TQOHCUIl{`1SCP7N6jT7mNO)+mL(cuc*2Ax%aW_XTobVKgcQY`q{j47_QNd( zwmd`YkfssY*xB^L{9@W(SmZ{HaACVo6ttGPJ?yB*xBhYw2+MkWSt5vPtX_9L*pN}A zZYh^3DXKE7p%{^7h?p)Ixu<14`LuOF`nXn>VVMXM5BPk(WXe7x9Et=vhDlW}JJ(G0 zlF5`+T`p1qnU&`5Fl-GQ*6kIZzTAVw*U7{cSip0u8;J)S$8d*q27m86gr_?W;19ts zJUK}4#lus0OqKah|K9RXj=bVKU%cnr?~dcIKlb9~dkH>tA&+yKxACUMy5k&FEhy-|ksT z0G17*fT5G$8u`m*ESy7u|c!&ih)eeLb(E{j$@#_A`#~ z*-K>^+B4|p+8X^gvg7@lWY7QfvCWi1PJ1s|Z@Y&V%undUkhZ@T z|HbaoEH!)f@NoO7Ob6|9WA!J;utrEqaIxi=mK2nh2=Q_8a4-sSa4~T(0WtIBi>$I5 f%3Oj)sVSw&sqw|7x$#H}*}w`JfeIO&m;|^0r0-vf