From 91122a2cf51516b89f483911d00a30da4315a2b0 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 21 Oct 2019 20:18:45 -0700 Subject: [PATCH] Fix GELU fusion (#2213) * Split graph_utils methods for finalization of fusion in order to support more than 2 nodes being fused into one. Update GELU fusion to use graph_utils to set up the input/output edges for the fused node, and to remove the nodes that are being replaced. --- onnxruntime/core/graph/graph_utils.cc | 48 +++++++++++++---- onnxruntime/core/graph/graph_utils.h | 30 +++++------ .../core/optimizer/conv_activation_fusion.cc | 2 +- onnxruntime/core/optimizer/gelu_fusion.cc | 52 +++++++++---------- .../core/optimizer/gemm_activation_fusion.cc | 2 +- .../core/optimizer/matmul_add_fusion.cc | 2 +- 6 files changed, 81 insertions(+), 55 deletions(-) diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 793202cfd1..6e598eb664 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -138,6 +138,16 @@ static void UpdateImplicitInputNameInSubgraph(Node& node, } } +/** Returns a vector of the input GraphEdges of a node. */ +static std::vector GetNodeInputEdges(const Node& node) { + std::vector input_edges; + for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) { + input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true)); + } + + return input_edges; +} + /** Returns a vector of the output GraphEdges of a node. */ static std::vector GetNodeOutputEdges(const Node& node) { std::vector output_edges; @@ -216,6 +226,20 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node) return true; } +/** Move the input edges that src_node has to target_node. +After the move is complete src_node will have no input edges. 
+*/ +static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) { + auto target_idx = target_node.Index(); + auto input_edges = GetNodeInputEdges(src_node); + + for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { + graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, cur->dst_arg_index); + } + + RemoveGraphEdges(graph, input_edges); +} + /** Move the output defs and edges from src_node to target_node. After the move is complete src_node will have no output edges and can be safely removed by Graph::RemoveNode. */ @@ -223,14 +247,14 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node) // copy the NodeArg*'s for all output defs. target_node.MutableOutputDefs() = src_node.MutableOutputDefs(); - auto src_idx = src_node.Index(); auto target_idx = target_node.Index(); auto output_edges = GetNodeOutputEdges(src_node); for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) { - graph.AddEdge(target_idx, cur->dst_node, 0, cur->dst_arg_index); - graph.RemoveEdge(src_idx, cur->dst_node, 0, cur->dst_arg_index); + graph.AddEdge(target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index); } + + RemoveGraphEdges(graph, output_edges); } //---------------------------- @@ -598,16 +622,22 @@ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) { target.MutableInputArgsCount()[target_input_idx] = 1; } -void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node) { - graph_utils::RemoveNodeOutputEdges(graph, first_node); - graph_utils::MoveAllNodeOutputs(graph, second_node, replacement_node ? 
*replacement_node : first_node); +void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node) { + // move the outputs from second_node to first_node + RemoveNodeOutputEdges(graph, first_node); + MoveAllNodeOutputs(graph, second_node, first_node); // second node now has no output edges and can be removed graph.RemoveNode(second_node.Index()); +} - if (replacement_node) { - // first_node has no output edges and can be removed - graph.RemoveNode(first_node.Index()); +void FinalizeNodeFusion(Graph& graph, const std::vector>& nodes, Node& replacement_node) { + MoveAllNodeInputEdges(graph, nodes.front(), replacement_node); + MoveAllNodeOutputs(graph, nodes.back(), replacement_node); + + for (Node& node : nodes) { + RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); } } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index c124b19a60..15334917c7 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -163,23 +163,23 @@ void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input); */ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input); -/** Finalize the fusion of two nodes. - Conceptually two nodes are being combined into one, and post-fusion will produce output/s with the same names - and be connected to the same downstream nodes. - -Two styles of fusion are supported. - -1. Both first_node and second_node are being removed and replaced by replacement_node - e.g. fusion of Conv and an activation into a new FusedConv node - The output definitions and edges from second_node are moved to replacement_node, and both first_node and - second_node are removed. - -2. The functionality of the second_node is being fused into first_node - e.g. fusion of Conv and Add into a single Conv node that does the Add via a bias input +/** Finalize the fusion of second_node into first_node. 
The output definitions and edges from the second_node are moved to first_node. second_node is deleted. + e.g. Conv + Add fusion fuses the 'Add' into the Conv. */ -void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node = nullptr); +void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node); + +/** Finalize the fusion of two or more nodes which are being replaced with a single node. + The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused. + + Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names + as the last node in 'nodes', and be connected to the same downstream nodes. + + The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved. + The output definitions and edges from the last node in 'nodes' will be moved to replacement_node. + All nodes in 'nodes' will be removed. +*/ +void FinalizeNodeFusion(Graph& graph, const std::vector>& nodes, Node& replacement_node); } // namespace graph_utils - } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 3380c74a92..b5e3eff2f8 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -141,7 +141,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l } // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node. 
- graph_utils::FinalizeNodeFusion(graph, conv_node, act_node, &fused_conv); + graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv); modified = true; } diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index bc8090f88a..345e22e2c9 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -70,10 +70,13 @@ static bool IsSupportedDataType(const Node& node) { Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - std::deque removed_nodes; for (auto node_index : node_topology_list) { - auto& div = *graph.GetNode(node_index); + auto* p_div = graph.GetNode(node_index); + if (p_div == nullptr) + continue; // we removed the node as part of an earlier fusion + + Node& div = *p_div; ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level)); if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) || @@ -88,7 +91,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons continue; } - const Node& erf_node = *(div.OutputNodesBegin()); + Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) || erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() || erf_node.GetOutputEdgesCount() != 1 || @@ -96,7 +99,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons continue; } - const Node& add_node = *(erf_node.OutputNodesBegin()); + Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) || add_node.GetExecutionProviderType() != div.GetExecutionProviderType() || add_node.GetOutputEdgesCount() != 1 || @@ -115,67 +118,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, 
int graph_level) cons continue; } - const Node& mul_node = *(add_node.OutputNodesBegin()); + Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index()); + // note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) || mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() || !IsSupportedDataType(mul_node)) { continue; } - const Node* mul2_node = nullptr; + const Node* p_mul2_node = nullptr; for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) { if ((*iter).OpType().compare("Mul") == 0) { // find the other input node of Mul - mul2_node = &(*iter); + p_mul2_node = &(*iter); break; } } - if (mul2_node == nullptr) { + if (p_mul2_node == nullptr) { continue; } - if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) || - mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() || - mul2_node->GetOutputEdgesCount() != 1 || - !IsSupportedDataType(*mul2_node)) { + Node& mul2_node = *graph.GetNode(p_mul2_node->Index()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) || + mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() || + mul2_node.GetOutputEdgesCount() != 1 || + !IsSupportedDataType(mul2_node)) { continue; } // Check the other input node(e.g. not of type Add) is 0.5f. 
int mul_const_input_index = 0; - if (mul2_node->InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) { + if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) { mul_const_input_index = 1; } - const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index]; + const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index]; if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) { continue; } const std::vector gelu_input_defs{div.MutableInputDefs()[0]}; - const std::vector gelu_output_defs{const_cast(mul_node.OutputDefs()[0])}; Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"), "Gelu", "fused Gelu subgraphs ", gelu_input_defs, - gelu_output_defs, {}, kMSDomain); + {}, {}, kMSDomain); // Assign provider to this new node. Provider should be same as the provider for old node. gelu_node.SetExecutionProviderType(div.GetExecutionProviderType()); - removed_nodes.push_front(div.Index()); - removed_nodes.push_front(erf_node.Index()); - removed_nodes.push_front(add_node.Index()); - removed_nodes.push_front(mul2_node->Index()); - removed_nodes.push_front(mul_node.Index()); - } + // move input edges to div (first in list) across to the gelu_node. + // move output definitions and output edges from mul_node (last in list) to gelu_node. + // remove all the other nodes. 
+ graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node); - // Have to remove node in reversed order for now to walk around the issue in RemoveNode - for (onnxruntime::NodeIndex removed_node : removed_nodes) { - graph.RemoveNode(removed_node); - } - - if (!removed_nodes.empty()) { modified = true; } diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index 6f8ee54ef3..756f53c7cc 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -73,7 +73,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l } // move output definitions and edges from act_node to fused_gemm. delete gemm_node and act_node. - graph_utils::FinalizeNodeFusion(graph, gemm_node, act_node, &fused_gemm); + graph_utils::FinalizeNodeFusion(graph, {gemm_node, act_node}, fused_gemm); modified = true; } diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 438d494e64..5500eb9ec1 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -110,7 +110,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); // move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node. - graph_utils::FinalizeNodeFusion(graph, matmul_node, add_node, &gemm_node); + graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node); modified = true; }