Fix GELU fusion (#2213)

* Split graph_utils methods for finalization of fusion in order to support more than 2 nodes being fused into one.
Update GELU fusion to use graph_utils to set up the input/output edges for the fused node, and to remove the nodes that are being replaced.
This commit is contained in:
Scott McKay 2019-10-21 20:18:45 -07:00 committed by Changming Sun
parent aef055ebe8
commit 91122a2cf5
6 changed files with 81 additions and 55 deletions

View file

@ -138,6 +138,16 @@ static void UpdateImplicitInputNameInSubgraph(Node& node,
}
}
/** Returns a vector of the input GraphEdges of a node. */
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node) {
std::vector<GraphEdge> input_edges;
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
}
return input_edges;
}
/** Returns a vector of the output GraphEdges of a node. */
static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node) {
std::vector<GraphEdge> output_edges;
@ -216,6 +226,20 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
return true;
}
/** Transfer every input edge of src_node so that it feeds target_node instead.
    Each edge keeps its producer node, producer output index and destination input index;
    only the destination node changes. On return src_node has no input edges.
*/
static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
  // Take a copy of the edges first so the graph can be modified while we walk them.
  auto edges_to_move = GetNodeInputEdges(src_node);
  auto new_dst_idx = target_node.Index();

  // Wire each producer up to target_node ...
  for (const auto& edge : edges_to_move) {
    graph.AddEdge(edge.src_node, new_dst_idx, edge.src_arg_index, edge.dst_arg_index);
  }

  // ... and then drop the original edges that pointed at src_node.
  RemoveGraphEdges(graph, edges_to_move);
}
/** Move the output defs and edges from src_node to target_node.
After the move is complete src_node will have no output edges and can be safely removed by Graph::RemoveNode.
*/
@ -223,14 +247,14 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node)
// copy the NodeArg*'s for all output defs.
target_node.MutableOutputDefs() = src_node.MutableOutputDefs();
auto src_idx = src_node.Index();
auto target_idx = target_node.Index();
auto output_edges = GetNodeOutputEdges(src_node);
for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) {
graph.AddEdge(target_idx, cur->dst_node, 0, cur->dst_arg_index);
graph.RemoveEdge(src_idx, cur->dst_node, 0, cur->dst_arg_index);
graph.AddEdge(target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index);
}
RemoveGraphEdges(graph, output_edges);
}
//----------------------------
@ -598,16 +622,22 @@ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) {
target.MutableInputArgsCount()[target_input_idx] = 1;
}
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node) {
graph_utils::RemoveNodeOutputEdges(graph, first_node);
graph_utils::MoveAllNodeOutputs(graph, second_node, replacement_node ? *replacement_node : first_node);
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node) {
// move the outputs from second_node to first_node
RemoveNodeOutputEdges(graph, first_node);
MoveAllNodeOutputs(graph, second_node, first_node);
// second node now has no output edges and can be removed
graph.RemoveNode(second_node.Index());
}
if (replacement_node) {
// first_node has no output edges and can be removed
graph.RemoveNode(first_node.Index());
// Finalize the fusion of a chain of nodes into replacement_node.
// The input edges of the first entry in 'nodes' and the output definitions/edges of the last entry
// are moved to replacement_node; every entry in 'nodes' is then removed from the graph.
// 'nodes' must not be empty (front()/back() are used unconditionally).
void FinalizeNodeFusion(Graph& graph, const std::vector<std::reference_wrapper<Node>>& nodes, Node& replacement_node) {
  Node& chain_head = nodes.front();
  Node& chain_tail = nodes.back();

  // The head of the chain supplies the inputs, the tail supplies the outputs.
  MoveAllNodeInputEdges(graph, chain_head, replacement_node);
  MoveAllNodeOutputs(graph, chain_tail, replacement_node);

  // Drop any remaining output edges (i.e. the edges internal to the chain) before removing each node.
  for (const std::reference_wrapper<Node>& fused_ref : nodes) {
    Node& fused = fused_ref.get();
    RemoveNodeOutputEdges(graph, fused);
    graph.RemoveNode(fused.Index());
  }
}

View file

@ -163,23 +163,23 @@ void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
*/
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
/** Finalize the fusion of two nodes.
Conceptually two nodes are being combined into one, and post-fusion will produce output/s with the same names
and be connected to the same downstream nodes.
Two styles of fusion are supported.
1. Both first_node and second_node are being removed and replaced by replacement_node
e.g. fusion of Conv and an activation into a new FusedConv node
The output definitions and edges from second_node are moved to replacement_node, and both first_node and
second_node are removed.
2. The functionality of the second_node is being fused into first_node
e.g. fusion of Conv and Add into a single Conv node that does the Add via a bias input
/** Finalize the fusion of second_node into first_node.
The output definitions and edges from the second_node are moved to first_node. second_node is deleted.
e.g. Conv + Add fusion fuses the 'Add' into the Conv.
*/
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node = nullptr);
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node);
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused.
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
as the last node in 'nodes', and be connected to the same downstream nodes.
The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved.
The output definitions and edges from the last node in 'nodes' will be moved to replacement_node.
All nodes in 'nodes' will be removed.
*/
void FinalizeNodeFusion(Graph& graph, const std::vector<std::reference_wrapper<Node>>& nodes, Node& replacement_node);
} // namespace graph_utils
} // namespace onnxruntime

View file

@ -141,7 +141,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
}
// move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
graph_utils::FinalizeNodeFusion(graph, conv_node, act_node, &fused_conv);
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
modified = true;
}

View file

@ -70,10 +70,13 @@ static bool IsSupportedDataType(const Node& node) {
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::deque<onnxruntime::NodeIndex> removed_nodes;
for (auto node_index : node_topology_list) {
auto& div = *graph.GetNode(node_index);
auto* p_div = graph.GetNode(node_index);
if (p_div == nullptr)
continue; // we removed the node as part of an earlier fusion
Node& div = *p_div;
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
@ -88,7 +91,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
continue;
}
const Node& erf_node = *(div.OutputNodesBegin());
Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
erf_node.GetOutputEdgesCount() != 1 ||
@ -96,7 +99,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
continue;
}
const Node& add_node = *(erf_node.OutputNodesBegin());
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
add_node.GetOutputEdgesCount() != 1 ||
@ -115,67 +118,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
continue;
}
const Node& mul_node = *(add_node.OutputNodesBegin());
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;
}
const Node* mul2_node = nullptr;
const Node* p_mul2_node = nullptr;
for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
if ((*iter).OpType().compare("Mul") == 0) {
// find the other input node of Mul
mul2_node = &(*iter);
p_mul2_node = &(*iter);
break;
}
}
if (mul2_node == nullptr) {
if (p_mul2_node == nullptr) {
continue;
}
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) ||
mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node->GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(*mul2_node)) {
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul2_node)) {
continue;
}
// Check that the other input of mul2_node (i.e. the one that is not the input shared with div) is the constant 0.5f.
int mul_const_input_index = 0;
if (mul2_node->InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
mul_const_input_index = 1;
}
const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index];
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) {
continue;
}
const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
const std::vector<NodeArg*> gelu_output_defs{const_cast<NodeArg*>(mul_node.OutputDefs()[0])};
Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"),
"Gelu",
"fused Gelu subgraphs ",
gelu_input_defs,
gelu_output_defs, {}, kMSDomain);
{}, {}, kMSDomain);
// Assign provider to this new node. Provider should be same as the provider for old node.
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
removed_nodes.push_front(div.Index());
removed_nodes.push_front(erf_node.Index());
removed_nodes.push_front(add_node.Index());
removed_nodes.push_front(mul2_node->Index());
removed_nodes.push_front(mul_node.Index());
}
// move input edges to div (first in list) across to the gelu_node.
// move output definitions and output edges from mul_node (last in list) to gelu_node.
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
// Have to remove node in reversed order for now to walk around the issue in RemoveNode
for (onnxruntime::NodeIndex removed_node : removed_nodes) {
graph.RemoveNode(removed_node);
}
if (!removed_nodes.empty()) {
modified = true;
}

View file

@ -73,7 +73,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
}
// move output definitions and edges from act_node to fused_gemm. delete gemm_node and act_node.
graph_utils::FinalizeNodeFusion(graph, gemm_node, act_node, &fused_gemm);
graph_utils::FinalizeNodeFusion(graph, {gemm_node, act_node}, fused_gemm);
modified = true;
}

View file

@ -110,7 +110,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
// move output definitions and edges from add_node to gemm_node. delete matmul_node and add_node.
graph_utils::FinalizeNodeFusion(graph, matmul_node, add_node, &gemm_node);
graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
modified = true;
}