From 15c1157ef2d179c911f4e16f87f074ea55ea8c8b Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 4 Jan 2023 17:51:14 +0800 Subject: [PATCH] New Pattern Support for LayerNormFusion (#14118) Latest torch exporter changed the LayerNorm exporting code to add two more Cast nodes (to make it logically correct in compute), but our current LayerNormFusion doesn't support the new pattern. The PR is to add support of this. --- .../core/optimizer/layer_norm_fusion.cc | 44 ++++++++----- .../test/optimizer/graph_transform_test.cc | 64 +++++++++++++++++++ 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index ef47f9b064..b124c2570d 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -52,9 +52,9 @@ due to restriction in older opsets. Therefore, Layer Normalization will also han | | | v X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add - | ^ - | | - +--------------------------------------------------------+ + | ^ + | | + +------------------------------------------------+ +---------------------+ Cast | | | | v v @@ -134,7 +134,6 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& sub_node = *graph.GetNode(p_sub_node->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) || sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || - !optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) || !IsSupportedDataType(sub_node)) { continue; } @@ -184,9 +183,24 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } // Apex O2 pattern specific match ends... - // Find the "Div" node after "Sub". + // Find the "Div" node after "Sub". It's possible that there is "Cast" node after "Sub" node. 
+ const Node* p_cast1 = nullptr; + if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) { + Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index()); + if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) && + cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() && + optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) { + p_cast1 = &cast_node; + nodes_to_remove.push_back(cast_node); + } + } + + if (!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 && !p_cast1 ? 2u : 1u)) { + continue; + } + const Node* p_div = nullptr; - p_div = graph_utils::FirstChildByType(sub_node, "Div"); + p_div = graph_utils::FirstChildByType(p_cast1 ? *p_cast1 : sub_node, "Div"); // Find the sub_dup node if exist if (p_sub_node_dup != nullptr) { @@ -269,23 +283,19 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, nodes_to_remove.push_back(pow_node); // check if Cast node exists: either between sub and pow, or as second input to pow - const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast"); - if (p_cast_node != nullptr) { - Node& cast_node = *graph.GetNode(p_cast_node->Index()); + const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast"); + if (p_cast2 != nullptr && p_cast2 != p_cast1) { + Node& cast_node = *graph.GetNode(p_cast2->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) || cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) { continue; } nodes_to_remove.push_back(cast_node); - - // Traceback from the last node in vector to find sub --> pow or sub --> cast - const Node* p_sub2_node = graph_utils::FirstParentByType(nodes_to_remove.back(), "Sub"); - if (p_sub2_node != nullptr) { - // Cast is between Sub and Pow - if ((p_sub2_node != p_sub_node && p_sub2_node != 
p_sub_node_dup) || !IsSupportedDataType(cast_node)) {
-          continue;
-        }
+    } else if (!p_cast2) {
+      const Node* p_sub2_node = graph_utils::FirstParentByType(pow_node, "Sub");
+      if (!p_sub2_node || (p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup)) {
+        continue;
       }
     }
 
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index eb0d2c9ebd..2044ade4ac 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4453,6 +4453,70 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
   }
 }
 
+TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* data_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
+    auto* pow_initializer = builder.MakeInitializer<float>({}, {2.0f});
+    auto* add_initializer = builder.MakeInitializer<float>({}, {1e-5f});
+    auto* weight_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(1.0f)));
+    auto* bias_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(0.0f)));
+    auto* reduce_mean_out_1 = builder.MakeIntermediate();
+    auto* sub_out = builder.MakeIntermediate();
+    auto* cast_out_1 = builder.MakeIntermediate();
+    auto* pow_out = builder.MakeIntermediate();
+    auto* reduce_mean_out_2 = builder.MakeIntermediate();
+    auto* add_out_1 = builder.MakeIntermediate();
+    auto* sqrt_out = builder.MakeIntermediate();
+    auto* div_out = builder.MakeIntermediate();
+    auto* cast_out_2 = builder.MakeIntermediate();
+    auto* mul_out = builder.MakeIntermediate();
+    auto* add_out_2 = builder.MakeOutput();
+
+    builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
+    builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
+    builder.AddNode("Cast", {sub_out}, {cast_out_1})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+    builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
+    builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
+    builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
+    builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
+    builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
+    builder.AddNode("Cast", {div_out}, {cast_out_2})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
+    builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out});
+    builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2});
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx"; std::shared_ptr p_model;