onnxruntime/onnxruntime/core/graph/graph_utils.h
Scott McKay 91122a2cf5 Fix GELU fusion (#2213)
* Split graph_utils methods for finalization of fusion in order to support more than 2 nodes being fused into one.
Update GELU fusion to use graph_utils to set up the input/output edges for the fused node, and removing nodes that are being replaced.
2019-10-21 20:18:45 -07:00

185 lines
9.5 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/onnx_protobuf.h"
#include "core/graph/graph.h"
namespace onnxruntime {
namespace graph_utils {
/** Checks if the operator's type, version, and domain of the given node match the given values.
@param node The node to check.
@param op_type The operator type to compare against.
@param versions The operator set version/s to compare against.
@param domain The operator set domain to compare against. Defaults to kOnnxDomainAlias.
@remarks NOTE(review): 'versions' is a list; presumably matching any single entry is sufficient.
         Confirm against the implementation.
*/
bool IsSupportedOptypeVersionAndDomain(const Node& node,
const std::string& op_type,
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion>& versions,
const std::string& domain = kOnnxDomainAlias);
/** Checks if the node's operator since version matches the given version/s.
@param node The node to check.
@param versions The since version/s to compare against.
@remarks NOTE(review): presumably matching any single entry of 'versions' is sufficient.
         Confirm against the implementation.
*/
bool MatchesOpSinceVersion(const Node& node, const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion>& versions);
/** Checks if the node has the same op set domain as the given one.
@param node The node to check.
@param domain The op set domain to compare against.
*/
bool MatchesOpSetDomain(const Node& node, const std::string& domain);
/** Returns true if the execution provider assigned to 'node' is present in the compatible_providers list,
or if the compatible_providers list is empty.
@param node The node whose assigned execution provider is checked.
@param compatible_providers The set of acceptable execution provider names. An empty set accepts any provider.
*/
bool IsSupportedProvider(const Node& node,
const std::unordered_set<std::string>& compatible_providers);
/** Checks if the output at the specified index is input to downstream Nodes.
@param node The node producing the output.
@param index The index of the output of 'node' to check.
*/
bool IsOutputUsed(const Node& node, int index);
/** Returns true if the graph has the given input.
@param graph The graph to check.
@param input The NodeArg to look for in the inputs of 'graph'.
@remarks NOTE(review): 'input' is passed by pointer; whether nullptr is accepted is not visible from this header.
*/
bool IsGraphInput(const Graph& graph, const NodeArg* input);
/** Returns true if 'name' is an initializer in 'graph', or in an ancestor graph if check_outer_scope is true.
@param graph The graph to search.
@param name The name of the initializer to look for.
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
@remarks Unlike IsConstantInitializer and GetConstantInitializer, check_outer_scope has no default value here
         and must be passed explicitly.
*/
bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope);
/** Returns true if 'name' is an initializer, is constant, and cannot be overridden at runtime.
@param graph The graph to search.
@param name The name of the initializer to look for.
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
       Defaults to true.
*/
bool IsConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true);
/** Returns the initializer's TensorProto if 'name' is an initializer, is constant, and
cannot be overridden at runtime.
@param graph The graph to search.
@param name The name of the initializer to look for.
@param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
       Defaults to true.
@returns The TensorProto for the initializer, or nullptr if the initializer is not found or is not constant.
*/
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, const std::string& name,
bool check_outer_scope = true);
/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
@param graph The graph to add the initializer to.
@param new_initializer The TensorProto of the initializer to add.
@returns The NodeArg for the new initializer.
@remarks No matching graph input is created, so the initializer will be constant.
         NOTE(review): what happens when the initializer already exists is not stated here.
         Confirm against the implementation before relying on it.
*/
NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer);
/** Checks if the given NodeArg is constant, i.e., it appears in the graph's initializers but not in its inputs.
@param graph The graph whose initializers and inputs are consulted.
@param node_arg The NodeArg to check.
*/
bool NodeArgIsConstant(const Graph& graph, const NodeArg& node_arg);
/** Checks if the given node has only constant inputs (initializers).
If so, the initializers are returned in constant_inputs, as they may come from outer scope.
@param graph The graph containing 'node'.
@param node The node whose inputs are checked.
@param constant_inputs Output parameter that receives the constant inputs of 'node'.
       NOTE(review): its contents when the check fails are not specified here.
@returns True if all inputs of 'node' are constant.
*/
bool AllNodeInputsAreConstant(const Graph& graph, const Node& node, InitializedTensorSet& constant_inputs);
/** Gets the name of the incoming NodeArg with the specified index for the given node.
@param node The node to query.
@param index The index of the input of 'node'.
       NOTE(review): behavior for an out-of-range index is not visible from this header.
*/
const std::string& GetNodeInputName(const Node& node, int index);
/** Gets the name of the outgoing NodeArg with the specified index for the given node.
@param node The node to query.
@param index The index of the output of 'node'.
       NOTE(review): behavior for an out-of-range index is not visible from this header.
*/
const std::string& GetNodeOutputName(const Node& node, int index);
/** Returns the attribute of a Node with a given name.
@param node The node to query.
@param attr_name The name of the attribute to look up.
@returns A pointer to the attribute. GetRepeatedNodeAttributeValues in this header tests the result for null,
         so a missing attribute is presumably reported as nullptr. Confirm against the implementation.
*/
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
/** Copies the values of a repeated attribute of a node into the 'values' vector.
@param node The node that owns the attribute.
@param attr_name The name of the repeated attribute to read.
@param values Receives the attribute's values. Left untouched if the attribute is not found.
@returns True if GetNodeAttribute returned an attribute and 'values' was assigned, otherwise false.
*/
template <typename T>
bool GetRepeatedNodeAttributeValues(const Node& node,
                                    const std::string& attr_name,
                                    std::vector<T>& values) {
  const auto* attribute = graph_utils::GetNodeAttribute(node, attr_name);

  // Nothing to retrieve when the lookup yields no attribute.
  if (!attribute) {
    return false;
  }

  values = ONNX_NAMESPACE::RetrieveValues<T>(*attribute);
  return true;
}
/** Tests if we can remove a node and merge its input edge (if any) with its output edges.
Conditions:
 Input rules:
 - the node has one input edge
   - it may have multiple other inputs coming from graph inputs or initializers
 - or the node has no input edges, and a single input
   - the input will be coming from a graph input or initializer
 Output rules:
 - Only one of the node's outputs is used by downstream operators
   - multiple edges for the single used output are allowed
 - The node does not produce a graph output
   - the node removal will result in that output name not being produced
 Subgraph rules:
 - Removing the node won't break a subgraph that consumes the node's output
@param graph The graph containing 'node'.
@param node The node being considered for removal.
@returns True if the conditions above are satisfied.
*/
bool CanRemoveNode(const Graph& graph, const Node& node);
/** Removes the given Node from the Graph.
See CanRemoveNode for the conditions that must be satisfied in order to remove the node.
If the node has one input edge, merge the input edge with any output edges.
If the node has no input edges it has a single input, so update any output edges to use the input as their source.
After output edges are updated, remove the node.
@param graph The graph containing 'node'.
@param node The node to remove.
@returns NOTE(review): the meaning of the bool result is not documented here; presumably true when the node
         was removed. Confirm against the implementation.
*/
bool RemoveNode(Graph& graph, Node& node);
/** Tests if we can remove a node and replace its output with an initializer.
Conditions:
 - Only one of the node's outputs is used by downstream operators or as a graph output
   - multiple edges for the single used output are allowed
 - If the node produces a graph output the initializer_name must be the same as the node's output name
   - otherwise the required graph output will not be produced
 - Removing the node won't break a subgraph that consumes the node's output
@param graph The graph containing 'node'.
@param node The node being considered for replacement.
@param initializer_name The name of the initializer that would replace the node's output.
@returns True if the conditions above are satisfied.
*/
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name);
/** Remove a node and replace its output with the provided NodeArg for an initializer.
See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.
@param graph The graph containing 'node'.
@param node The node to remove.
@param replacement The NodeArg for the initializer that replaces the node's output.
@returns NOTE(review): the meaning of the bool result is not documented here; presumably true when the node
         was replaced. Confirm against the implementation.
*/
bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement);
/** Removes all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually.
@param graph The graph containing 'node'.
@param node The node whose output edges are removed.
@returns NOTE(review): the meaning of the size_t result is not documented here; presumably the number of
         output edges that were removed. Confirm against the implementation.
*/
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node.
@param graph The graph containing the nodes.
@param node The node whose output is currently consumed by the downstream nodes.
@param output_idx The index of the output from 'node' that is being replaced.
@param replacement The node providing the replacement output.
@param replacement_output_idx The index of the output from 'replacement' to use.

e.g. Node A produces outputs A1 and A2.
     Node B consumes A2 (edge between A and B for A2) and produces B1.
     Node C consumes B1 (edge between B and C for B1).

     If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1)
     to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C.
     The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2.
*/
void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx);
/** Replace the input to a node with a NodeArg.
@param target The node whose input is replaced.
@param target_input_idx The index of the input of 'target' to replace.
@param new_input The NodeArg to use as the new input.
@remarks The replacement only updates the node's input definition and does not create any edges,
         as typically this function is used to replace an input with an initializer or graph input
         (there is no edge between an initializer or graph input and a Node).
*/
void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
/** Add an input to a node with a NodeArg for an initializer or graph input.
@param target The node to add the input to.
@param target_input_idx The slot for the new input. Must be the next input slot.
@param new_input The NodeArg for the initializer or graph input to add.
@remarks target_input_idx must be the next input slot.
         e.g. if a Node has 2 inputs, AddNodeInput can only add the third input and not the fourth.
         NOTE(review): the example counts inputs from one; presumably target_input_idx itself is zero-based
         (i.e. 2 for the third input). Confirm against the implementation.
         There is no edge between an initializer or graph input and a Node, so the addition only updates the
         node's input definition and does not create any new edges.
*/
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
/** Finalize the fusion of second_node into first_node.
The output definitions and edges from the second_node are moved to first_node. second_node is deleted.
e.g. Conv + Add fusion fuses the 'Add' into the Conv.
@param graph The graph containing both nodes.
@param first_node The node that remains and receives the output definitions and edges of second_node.
@param second_node The node being fused into first_node. It is deleted, so it must not be used after this call.
*/
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node);
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused.

Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
as the last node in 'nodes', and be connected to the same downstream nodes.

The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved.
The output definitions and edges from the last node in 'nodes' will be moved to replacement_node.
All nodes in 'nodes' will be removed.
@param graph The graph containing the nodes.
@param nodes The chain of nodes being fused, ordered from first to last. All of them are removed.
@param replacement_node The single node that replaces the nodes in 'nodes'.
*/
void FinalizeNodeFusion(Graph& graph, const std::vector<std::reference_wrapper<Node>>& nodes, Node& replacement_node);
} // namespace graph_utils
} // namespace onnxruntime