diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.cc b/onnxruntime/core/optimizer/quick_gelu_fusion.cc index 93de7a64bd..d8e627fbb0 100644 --- a/onnxruntime/core/optimizer/quick_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/quick_gelu_fusion.cc @@ -30,7 +30,8 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, int alpha_index = -1; float alpha = 1.0f; if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) && - graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1) { + graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1 && + !graph.NodeProducesGraphOutput(node)) { for (int i = 0; i < static_cast(node.InputDefs().size()); ++i) { const NodeArg& input_arg = *(node.InputDefs()[i]); if (!optimizer_utils::IsScalar(input_arg)) continue; @@ -68,7 +69,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& sigmoid_node = *p_sigmoid_node; if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) || !graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) || - sigmoid_node.GetOutputEdgesCount() != 1) { + sigmoid_node.GetOutputEdgesCount() != 1 || graph.NodeProducesGraphOutput(sigmoid_node)) { continue; } nodes_to_fuse.emplace_back(sigmoid_node); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ddd1e43395..da9cd3caac 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4049,6 +4049,72 @@ TEST_F(GraphTransformationTests, QuickGelu) { pre_graph_checker, post_graph_checker)); } + // Sigmoid's output is a graph output. + { + constexpr float alpha = 1.702f; + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* alpha_arg = builder.MakeInitializer({}, {alpha}); + auto* mul_out_0 = builder.MakeIntermediate(); + auto* sigmoid_out = builder.MakeOutput(); + auto* mul_out_1 = builder.MakeOutput(); + + builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0}); + builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out}); + builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // First Mul's output is a graph output. + { + constexpr float alpha = 1.702f; + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* alpha_arg = builder.MakeInitializer({}, {alpha}); + auto* mul_out_0 = builder.MakeOutput(); + auto* sigmoid_out = builder.MakeIntermediate(); + auto* mul_out_1 = builder.MakeOutput(); + + builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0}); + builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out}); + builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sigmoid"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.QuickGelu"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + // Sigmoid(x)*x, float { auto build_test_case = [&](ModelTestBuilder& builder) {