mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-18 01:54:05 +00:00
Fix GELU fusion (#2213)
* Split graph_utils methods for finalization of fusion in order to support more than 2 nodes being fused into one. Update GELU fusion to use graph_utils to set up the input/output edges for the fused node, and removing nodes that are being replaced.
This commit is contained in:
parent
aef055ebe8
commit
91122a2cf5
6 changed files with 81 additions and 55 deletions
|
|
@ -138,6 +138,16 @@ static void UpdateImplicitInputNameInSubgraph(Node& node,
|
|||
}
|
||||
}
|
||||
|
||||
/** Returns a vector of the input GraphEdges of a node. */
|
||||
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node) {
|
||||
std::vector<GraphEdge> input_edges;
|
||||
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
|
||||
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
|
||||
}
|
||||
|
||||
return input_edges;
|
||||
}
|
||||
|
||||
/** Returns a vector of the output GraphEdges of a node. */
|
||||
static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node) {
|
||||
std::vector<GraphEdge> output_edges;
|
||||
|
|
@ -216,6 +226,20 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
|
|||
return true;
|
||||
}
|
||||
|
||||
/** Move the input edges that src_node has to target_node.
|
||||
After the move is complete src_node will have no input edges.
|
||||
*/
|
||||
static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
|
||||
auto target_idx = target_node.Index();
|
||||
auto input_edges = GetNodeInputEdges(src_node);
|
||||
|
||||
for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) {
|
||||
graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, cur->dst_arg_index);
|
||||
}
|
||||
|
||||
RemoveGraphEdges(graph, input_edges);
|
||||
}
|
||||
|
||||
/** Move the output defs and edges from src_node to target_node.
|
||||
After the move is complete src_node will have no output edges and can be safely removed by Graph::RemoveNode.
|
||||
*/
|
||||
|
|
@ -223,14 +247,14 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node)
|
|||
// copy the NodeArg*'s for all output defs.
|
||||
target_node.MutableOutputDefs() = src_node.MutableOutputDefs();
|
||||
|
||||
auto src_idx = src_node.Index();
|
||||
auto target_idx = target_node.Index();
|
||||
auto output_edges = GetNodeOutputEdges(src_node);
|
||||
|
||||
for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) {
|
||||
graph.AddEdge(target_idx, cur->dst_node, 0, cur->dst_arg_index);
|
||||
graph.RemoveEdge(src_idx, cur->dst_node, 0, cur->dst_arg_index);
|
||||
graph.AddEdge(target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index);
|
||||
}
|
||||
|
||||
RemoveGraphEdges(graph, output_edges);
|
||||
}
|
||||
|
||||
//----------------------------
|
||||
|
|
@ -598,16 +622,22 @@ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) {
|
|||
target.MutableInputArgsCount()[target_input_idx] = 1;
|
||||
}
|
||||
|
||||
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node) {
|
||||
graph_utils::RemoveNodeOutputEdges(graph, first_node);
|
||||
graph_utils::MoveAllNodeOutputs(graph, second_node, replacement_node ? *replacement_node : first_node);
|
||||
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node) {
|
||||
// move the outputs from second_node to first_node
|
||||
RemoveNodeOutputEdges(graph, first_node);
|
||||
MoveAllNodeOutputs(graph, second_node, first_node);
|
||||
|
||||
// second node now has no output edges and can be removed
|
||||
graph.RemoveNode(second_node.Index());
|
||||
}
|
||||
|
||||
if (replacement_node) {
|
||||
// first_node has no output edges and can be removed
|
||||
graph.RemoveNode(first_node.Index());
|
||||
void FinalizeNodeFusion(Graph& graph, const std::vector<std::reference_wrapper<Node>>& nodes, Node& replacement_node) {
|
||||
MoveAllNodeInputEdges(graph, nodes.front(), replacement_node);
|
||||
MoveAllNodeOutputs(graph, nodes.back(), replacement_node);
|
||||
|
||||
for (Node& node : nodes) {
|
||||
RemoveNodeOutputEdges(graph, node);
|
||||
graph.RemoveNode(node.Index());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -163,23 +163,23 @@ void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
|
|||
*/
|
||||
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
|
||||
|
||||
/** Finalize the fusion of two nodes.
|
||||
Conceptually two nodes are being combined into one, and post-fusion will produce output/s with the same names
|
||||
and be connected to the same downstream nodes.
|
||||
|
||||
Two styles of fusion are supported.
|
||||
|
||||
1. Both first_node and second_node are being removed and replaced by replacement_node
|
||||
e.g. fusion of Conv and an activation into a new FusedConv node
|
||||
The output definitions and edges from second_node are moved to replacement_node, and both first_node and
|
||||
second_node are removed.
|
||||
|
||||
2. The functionality of the second_node is being fused into first_node
|
||||
e.g. fusion of Conv and Add into a single Conv node that does the Add via a bias input
|
||||
/** Finalize the fusion of second_node into first_node.
|
||||
The output definitions and edges from the second_node are moved to first_node. second_node is deleted.
|
||||
e.g. Conv + Add fusion fuses the 'Add' into the Conv.
|
||||
*/
|
||||
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node, Node* replacement_node = nullptr);
|
||||
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node);
|
||||
|
||||
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
|
||||
The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused.
|
||||
|
||||
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
|
||||
as the last node in 'nodes', and be connected to the same downstream nodes.
|
||||
|
||||
The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved.
|
||||
The output definitions and edges from the last node in 'nodes' will be moved to replacement_node.
|
||||
All nodes in 'nodes' will be removed.
|
||||
*/
|
||||
void FinalizeNodeFusion(Graph& graph, const std::vector<std::reference_wrapper<Node>>& nodes, Node& replacement_node);
|
||||
|
||||
} // namespace graph_utils
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
}
|
||||
|
||||
// move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
|
||||
graph_utils::FinalizeNodeFusion(graph, conv_node, act_node, &fused_conv);
|
||||
graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,10 +70,13 @@ static bool IsSupportedDataType(const Node& node) {
|
|||
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::deque<onnxruntime::NodeIndex> removed_nodes;
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& div = *graph.GetNode(node_index);
|
||||
auto* p_div = graph.GetNode(node_index);
|
||||
if (p_div == nullptr)
|
||||
continue; // we removed the node as part of an earlier fusion
|
||||
|
||||
Node& div = *p_div;
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
|
||||
|
|
@ -88,7 +91,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
|
|||
continue;
|
||||
}
|
||||
|
||||
const Node& erf_node = *(div.OutputNodesBegin());
|
||||
Node& erf_node = *graph.GetNode(div.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
|
||||
erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
erf_node.GetOutputEdgesCount() != 1 ||
|
||||
|
|
@ -96,7 +99,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
|
|||
continue;
|
||||
}
|
||||
|
||||
const Node& add_node = *(erf_node.OutputNodesBegin());
|
||||
Node& add_node = *graph.GetNode(erf_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
|
||||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
add_node.GetOutputEdgesCount() != 1 ||
|
||||
|
|
@ -115,67 +118,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) cons
|
|||
continue;
|
||||
}
|
||||
|
||||
const Node& mul_node = *(add_node.OutputNodesBegin());
|
||||
Node& mul_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
|
||||
// note: output edges count doesn't matter as the new Gelu node will produce outputs with the same names
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node* mul2_node = nullptr;
|
||||
const Node* p_mul2_node = nullptr;
|
||||
for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
|
||||
if ((*iter).OpType().compare("Mul") == 0) {
|
||||
// find the other input node of Mul
|
||||
mul2_node = &(*iter);
|
||||
p_mul2_node = &(*iter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mul2_node == nullptr) {
|
||||
if (p_mul2_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) ||
|
||||
mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
mul2_node->GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(*mul2_node)) {
|
||||
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
|
||||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
mul2_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(mul2_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the other input node(e.g. not of type Add) is 0.5f.
|
||||
int mul_const_input_index = 0;
|
||||
if (mul2_node->InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
|
||||
if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
|
||||
mul_const_input_index = 1;
|
||||
}
|
||||
|
||||
const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index];
|
||||
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
|
||||
if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
|
||||
const std::vector<NodeArg*> gelu_output_defs{const_cast<NodeArg*>(mul_node.OutputDefs()[0])};
|
||||
Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"),
|
||||
"Gelu",
|
||||
"fused Gelu subgraphs ",
|
||||
gelu_input_defs,
|
||||
gelu_output_defs, {}, kMSDomain);
|
||||
{}, {}, kMSDomain);
|
||||
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
|
||||
|
||||
removed_nodes.push_front(div.Index());
|
||||
removed_nodes.push_front(erf_node.Index());
|
||||
removed_nodes.push_front(add_node.Index());
|
||||
removed_nodes.push_front(mul2_node->Index());
|
||||
removed_nodes.push_front(mul_node.Index());
|
||||
}
|
||||
// move input edges to div (first in list) across to the gelu_node.
|
||||
// move output definitions and output edges from mul_node (last in list) to gelu_node.
|
||||
// remove all the other nodes.
|
||||
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
|
||||
|
||||
// Have to remove node in reversed order for now to walk around the issue in RemoveNode
|
||||
for (onnxruntime::NodeIndex removed_node : removed_nodes) {
|
||||
graph.RemoveNode(removed_node);
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
}
|
||||
|
||||
// move output definitions and edges from act_node to fused_gemm. delete gemm_node and act_node.
|
||||
graph_utils::FinalizeNodeFusion(graph, gemm_node, act_node, &fused_gemm);
|
||||
graph_utils::FinalizeNodeFusion(graph, {gemm_node, act_node}, fused_gemm);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
|
||||
|
||||
// move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
|
||||
graph_utils::FinalizeNodeFusion(graph, matmul_node, add_node, &gemm_node);
|
||||
graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue