diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index ef47f9b064..b124c2570d 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -52,9 +52,9 @@ due to restriction in older opsets. Therefore, Layer Normalization will also han | | | v X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add - | ^ - | | - +--------------------------------------------------------+ + | ^ + | | + +------------------------------------------------+ +---------------------+ Cast | | | | v v @@ -134,7 +134,6 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& sub_node = *graph.GetNode(p_sub_node->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(sub_node, "Sub", {7, 13, 14}) || sub_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || - !optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 ? 2u : 1u) || !IsSupportedDataType(sub_node)) { continue; } @@ -184,9 +183,24 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } // Apex O2 pattern specific match ends... - // Find the "Div" node after "Sub". + // Find the "Div" node after "Sub". It's possible that there is "Cast" node after "Sub" node. + const Node* p_cast1 = nullptr; + if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) { + Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index()); + if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) && + cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() && + optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) { + p_cast1 = &cast_node; + nodes_to_remove.push_back(cast_node); + } + } + + if (!optimizer_utils::CheckOutputEdges(graph, sub_node, subCnt == 1 && !p_cast1 ? 
2u : 1u)) { + continue; + } + const Node* p_div = nullptr; - p_div = graph_utils::FirstChildByType(sub_node, "Div"); + p_div = graph_utils::FirstChildByType(p_cast1 ? *p_cast1 : sub_node, "Div"); // Find the sub_dup node if exist if (p_sub_node_dup != nullptr) { @@ -269,23 +283,19 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, nodes_to_remove.push_back(pow_node); // check if Cast node exists: either between sub and pow, or as second input to pow - const Node* p_cast_node = graph_utils::FirstParentByType(pow_node, "Cast"); - if (p_cast_node != nullptr) { - Node& cast_node = *graph.GetNode(p_cast_node->Index()); + const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast"); + if (p_cast2 != nullptr && p_cast2 != p_cast1) { + Node& cast_node = *graph.GetNode(p_cast2->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) || cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) { continue; } nodes_to_remove.push_back(cast_node); - - // Traceback from the last node in vector to find sub --> pow or sub --> cast - const Node* p_sub2_node = graph_utils::FirstParentByType(nodes_to_remove.back(), "Sub"); - if (p_sub2_node != nullptr) { - // Cast is between Sub and Pow - if ((p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup) || !IsSupportedDataType(cast_node)) { - continue; - } + } else if (!p_cast2) { + const Node* p_sub2_node = graph_utils::FirstParentByType(pow_node, "Sub"); + if (!p_sub2_node || (p_sub2_node != p_sub_node && p_sub2_node != p_sub_node_dup)) { + continue; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index eb0d2c9ebd..2044ade4ac 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4453,6 +4453,70 @@ TEST_F(GraphTransformationTests, 
LayerNormWithSubDupFusionTest) { } } +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}}); + auto* pow_initializer = builder.MakeInitializer<float>({}, {2.0f}); + auto* add_initializer = builder.MakeInitializer<float>({}, {1e-5f}); + auto* weight_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(1.0f))); + auto* bias_initializer = builder.MakeInitializer<MLFloat16>({3}, std::vector<MLFloat16>(3, MLFloat16(0.0f))); + auto* reduce_mean_out_1 = builder.MakeIntermediate(); + auto* sub_out = builder.MakeIntermediate(); + auto* cast_out_1 = builder.MakeIntermediate(); + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out_2 = builder.MakeIntermediate(); + auto* add_out_1 = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* cast_out_2 = builder.MakeIntermediate(); + auto* mul_out = builder.MakeIntermediate(); + auto* add_out_2 = builder.MakeOutput(); + + builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1}); + builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out}); + builder.AddNode("Cast", {sub_out}, {cast_out_1}) + .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1}); + builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1}); + builder.AddNode("Sqrt", {add_out_1}, {sqrt_out}); + builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out}); + builder.AddNode("Cast", {div_out}, {cast_out_2}) + .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out}); + builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2}); + }; 
+ + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1); + return Status::OK(); + }; + + std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx"; std::shared_ptr<Model> p_model;