diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 8136e39ed7..a1bec8a3ca 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -103,24 +103,24 @@ static void AddNodes(std::vector& node_indices, It is because they are matched as part of other subgraph. */ -static bool MatchPositionSubgraph( +static bool MatchInputToConcatSubgraph( Graph& graph, - const Node& expand_node, + const Node& cur_node, const NodeArg* input_ids, + const int index, const logging::Logger& logger, std::vector& subgraph_node_indices, const NodeIndex expected_gather_node_1_index) { subgraph_node_indices.clear(); - std::vector expand_parent_path1{ - {0, 1, "Concat", {4, 11}, kOnnxDomain}, + {0, index, "Concat", {4, 11}, kOnnxDomain}, {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, {0, 0, "Gather", {1, 11}, kOnnxDomain}, {0, 0, "Shape", {1}, kOnnxDomain}, }; std::vector edges; - if (!graph_utils::FindPath(expand_node, true, expand_parent_path1, edges, logger)) { + if (!graph_utils::FindPath(cur_node, true, expand_parent_path1, edges, logger)) { DEBUG_LOG("Failed to find path 1 of position shape."); return false; } @@ -184,60 +184,80 @@ static bool MatchPositionSubgraph( } /** Match subgraph like the following: - (input_ids) - / \ - Shape Shape - | | - ^Gather (indice=0)^ Gather (indice=1)--+ - ^|^ ^|^ | - ^Unsqueeze^ ^Unsqueeze^ Unsqueeze - ^\^ ^/^ | - ^\^ ^/^ ConstantOfShape - ^\^ ^/^ | - ^Concat^ NonZero - | | - | Transpose - | | - | Squeeze - | | - | Cast - | | - | Unsqueeze - +--|----------------------------+ - | | - Expand - | - Gather - - Note that position gather node is the node in the bottom of above sub-graph. - Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. -*/ -static bool MatchPositionEmbeddingSubgraph1( + * + * Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^ + * / | +-----------------------+ + * / v | | + * [input_ids] ^Concat^ -> *Reshape* -> *Equal* -> *Where* -> Expand -> Gather + * \ | | ("position") + * Shape -> ^Gather (indice=1)^ -> ^Unsqueeze^ | + * | | + * +-------------- # one of the below subgraph patterns # ---------------+ + * # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze # + * # or # + * # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze # + * + * Note that position gather node is the node in the bottom of above sub-graph. + * Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. + * Path in ** is an alternative path to check. + */ +static bool MatchPositionEmbeddingSubgraphsFromGather( Graph& graph, const Node& position_gather_node, const NodeArg* input_ids, const logging::Logger& logger, std::vector& subgraph_node_indices) { subgraph_node_indices.clear(); - std::vector pg_edges; // Look for Path 1: - // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand --> Gather - if (!graph_utils::FindPath(position_gather_node, true, - {{0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Cast", {9}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, - {0, 0, "Transpose", {1}, kOnnxDomain}, - {0, 0, "NonZero", {9}, kOnnxDomain}, - {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}, - pg_edges, logger)) { + // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze + // --> Cast --> Unsqueeze --> Expand --> Gather + std::vector parent_path_1{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Cast", {9}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 0, "Transpose", {1}, kOnnxDomain}, + {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Look for Path 2 (Path 1 with no cast): + std::vector parent_path_2{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 0, "Transpose", {1}, kOnnxDomain}, + {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Path 3 Pattern: + // Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand + std::vector parent_path_3{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Range", {1, 11}, kOnnxDomain}, + {0, 1, "Cast", {9}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Path 4 pattern (Path 3 with no "Cast"): + std::vector parent_path_4{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Range", {1, 11}, kOnnxDomain}, + {0, 1, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Match one of the three path patterns. + if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_3, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_4, pg_edges, logger)) { return false; } - const size_t gather_index = 8; + const size_t gather_index = pg_edges.size() - 2; // All nodes in Path 1 must have only 1 output edge, except the gather node allowed 1 or 2 output edges for (size_t i = 0; i < pg_edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) { @@ -251,6 +271,18 @@ static bool MatchPositionEmbeddingSubgraph1( Node& expand_node = *graph.GetNode(pg_edges[0]->GetNode().Index()); Node& gather_node = *graph.GetNode(pg_edges[gather_index]->GetNode().Index()); + if (pg_edges[2]->GetNode().OpType() == "Range") { + // Check if the values in "start" and "delta" attributes in Range are expected. + Node& range_node = *graph.GetNode(pg_edges[2]->GetNode().Index()); + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) { + DEBUG_LOG("The first input of Range should be a constant with value 0."); + return false; + } + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) { + DEBUG_LOG("The third input of Range should be a constant with value 1."); + return false; + } + } if (gather_node.GetOutputEdgesCount() == 1) { // Check if the second input of the Gather node in the path has a constant input of 1 @@ -279,7 +311,35 @@ static bool MatchPositionEmbeddingSubgraph1( subgraph_node_indices.push_back(shape_node_index); } else { // gather_output_edges_count == 2 - if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node.Index())) { + // Match optional Reshape -> Equal -> Where -> Expand + // | | + // -------------------- + std::vector pg_edges_2; + std::vector path_to_match_1{ + {0, 1, "Where", {9}, kOnnxDomain}, + {0, 0, "Equal", {1, 11}, kOnnxDomain}, + {0, 0, "Reshape", {5}, kOnnxDomain}}; + if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) { + if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) || + !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) || + !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[2]->GetNode(), 2)) { + DEBUG_LOG("Optional position subgraph nodes number of outputs unexpected."); + return false; + } + Node& where_node = *graph.GetNode(pg_edges_2[0]->GetNode().Index()); + Node& reshape_node = *graph.GetNode(pg_edges_2[2]->GetNode().Index()); + if (where_node.MutableInputDefs()[2] != reshape_node.MutableOutputDefs()[0]) { + DEBUG_LOG("Optional position subgraph nodes Where node is expected to be the parent of Reshape."); + return false; + } + // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node. + if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) { + DEBUG_LOG("Failed to match position subgraph."); + return false; + } + AddNodes(subgraph_node_indices, pg_edges_2); + } else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) { + // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node. DEBUG_LOG("Failed to match position subgraph."); return false; } @@ -290,90 +350,6 @@ static bool MatchPositionEmbeddingSubgraph1( return true; } -/** Match subgraph like the following: - (input_ids) - / \ - Shape Shape - | | - Gather (indice=0) Gather (indice=1)--+ - | | | - Unsqueeze Unsqueeze Cast(to=7) (Cast is optional) - \ / | - \ / Range(start=0, delta=1) - \ / | - Concat Unsqueeze - | | - +--|----------------------------+ - | | - Expand - | - Gather - - Note that position gather node is the node in the bottom of above sub-graph. -*/ - -static bool MatchPositionEmbeddingSubgraph2( - Graph& graph, - const Node& position_gather_node, - const NodeArg* input_ids, - const logging::Logger& logger, - std::vector& subgraph_node_indices) { - subgraph_node_indices.clear(); - - // Match Gather <-- Expand <-- Unsqueeze <-- Range <-- Cast <-- Gather - // Since Range is from opset 11, we only match opset 11 here. - std::vector position_parent_nodes; - std::vector position_embedding_path_symbolic{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {11}, kOnnxDomain}, - {0, 0, "Range", {11}, kOnnxDomain}, - {0, 1, "Cast", {9}, kOnnxDomain}, - {0, 0, "Gather", {11}, kOnnxDomain}}; - std::vector edges; - - if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) { - // Cast node might be removed by other optimizer. Here we check a pattern without Cast node. - std::vector position_embedding_path_no_cast{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {11}, kOnnxDomain}, - {0, 0, "Range", {11}, kOnnxDomain}, - {0, 1, "Gather", {11}, kOnnxDomain}}; - if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_no_cast, edges, logger)) { - DEBUG_LOG("Failed to find path 1."); - return false; - } - } - - size_t last_edge = edges.size() - 1; - for (size_t i = 0; i < edges.size(); i++) { - if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), (i == last_edge ? 2u : 1u))) { - DEBUG_LOG("Output edge count not expected for nodes in path 1."); - return false; - } - } - - Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index()); - Node& range_node = *graph.GetNode(edges[2]->GetNode().Index()); - Node& gather_node_1 = *graph.GetNode(edges[last_edge]->GetNode().Index()); - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) { - DEBUG_LOG("The first input of Range should be a constant with value 0."); - return false; - } - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) { - DEBUG_LOG("The third input of Range should be a constant with value 1."); - return false; - } - - if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node_1.Index())) { - DEBUG_LOG("Failed to match position subgraph."); - return false; - } - - AddNodes(subgraph_node_indices, edges); - - return true; -} - static bool MatchPositionEmbeddingSubgraph( Graph& graph, const Node& add_node, @@ -422,10 +398,8 @@ static bool MatchPositionEmbeddingSubgraph( } } } else { - if (!MatchPositionEmbeddingSubgraph1(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { - if (!MatchPositionEmbeddingSubgraph2(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { - return false; - } + if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { + return false; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 94be6c3469..c59c6f34f6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2424,6 +2424,35 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { } } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { + auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["Reshape"], 0); + EXPECT_EQ(op_to_count["Equal"], 0); + EXPECT_EQ(op_to_count["Where"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); +} + TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) { auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx new file mode 100644 index 0000000000..dd51076224 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index f7b9cd30ab..97e7abfb55 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -143,7 +143,86 @@ def GenerateModel5(model_name): model = helper.make_model(graph) onnx.save(model, model_name) +def GenerateModel6(model_name): + nodes = [ # LayerNorm subgraph + helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), + helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"), + helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"), + helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"), + helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0), + + helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"), + helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"), + helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"), + + helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], + "range"), + helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + + helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"), + helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"), + helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"), + helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"), + helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather"), + helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"), + helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752), + helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), + helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], + "att", + domain="com.microsoft", + num_heads=2), + helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), + helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"), + helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2") + ] + + + # hidden_size=4, num_heads=2, max_seq_length=3 + initializers = [ # initializers + helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), + helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), + helper.make_tensor('start_0', TensorProto.INT64, [], [0]), + helper.make_tensor('delta_1', TensorProto.INT64, [], [1]), + helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]), + helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]), + helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]), + ] + + graph = helper.make_graph( + nodes, + "EmbedLayerNorm_format6", #name + [ # inputs + helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]), + ], + [ # outputs + helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) GenerateModel3('embed_layer_norm_format3.onnx', True) GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) GenerateModel5('embed_layer_norm_format5.onnx') +GenerateModel6('embed_layer_norm_format6.onnx') \ No newline at end of file