onnxruntime/onnxruntime/core/graph/graph_utils.cc
guyang3532 3e4db2c686
Fuse Cast + SoftmaxCrossEntropyLossInternal (#20334)
### Description
Fuse Cast + SoftmaxCrossEntropyLossInternal to
SoftmaxCrossEntropyLossInternal.
2024-04-29 14:12:10 +08:00

901 lines
37 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_utils.h"
#include <queue>
#include "core/graph/graph.h"
#include "core/common/logging/logging.h"
namespace onnxruntime {
namespace graph_utils {
//---------------------
//--- local helpers ---
//---------------------
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Returns the position of the def named `name` among the node's input defs (is_input == true)
    or output defs (is_input == false). Fails an ORT_ENFORCE check if no def has that name. */
static int GetIndexFromName(const Node& node, const std::string& name, bool is_input) {
  const auto& node_args = is_input ? node.InputDefs() : node.OutputDefs();
  const auto has_requested_name = [&name](const NodeArg* node_arg) { return node_arg->Name() == name; };

  auto itr = std::find_if(node_args.begin(), node_args.end(), has_requested_name);
  ORT_ENFORCE(itr != node_args.end(),
              "Attempting to get index by a name which does not exist:", name, "for node: ", node.Name());

  return static_cast<int>(std::distance(node_args.begin(), itr));
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
// check if an output edge provides an implicit input to the destination node
static bool OutputEdgeProvidesImplicitInput(const Graph& graph, const GraphEdge& output_edge) {
// we treat the explicit and implicit inputs as sequential, so if the destination arg index of an output edge
// is past the valid range for the node's explicit inputs, it is for an implicit input
const size_t num_explicit_inputs = (*graph.GetNode(output_edge.dst_node)).InputDefs().size();
return static_cast<size_t>(output_edge.dst_arg_index) >= num_explicit_inputs;
}
/** Checks if new_output_name can be used to replace removed_output_name in the subgraph input.
If there is an existing NodeArg in a subgraph that implicitly consumes removed_output_name, it is not safe. */
static bool CanUpdateImplicitInputNameInSubgraph(const Node& node,
const std::string& removed_output_name,
const std::string& new_output_name) {
if (!node.ContainsSubgraph())
return true;
for (const gsl::not_null<const Graph*>& subgraph : node.GetSubgraphs()) {
// if we have an existing NodeArg in the subgraph with the new_output_name that would override an implicit input
// with the same name
if (subgraph->GetNodeArg(new_output_name) != nullptr) {
return false;
}
for (auto& subgraph_node : subgraph->Nodes()) {
// recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested
// subgraphs, and at least one level lower uses removed_output_name as an implicit input
const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
if (!subgraph_node_implicit_inputs.empty()) {
auto subgraph_node_also_consumes_nodearg_as_implicit_input =
std::find_if(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),
[&removed_output_name](const NodeArg* input) {
return input != nullptr && input->Name() == removed_output_name;
});
if (subgraph_node_also_consumes_nodearg_as_implicit_input != subgraph_node_implicit_inputs.cend()) {
if (!CanUpdateImplicitInputNameInSubgraph(subgraph_node, removed_output_name, new_output_name))
return false;
}
}
}
}
return true;
}
/** Replaces removed_output_name with new_output_name in the subgraphs of `node`.
    For every node of every subgraph:
      - if the node lists removed_output_name among its implicit inputs, its own subgraphs are updated
        recursively first;
      - every explicit input def named removed_output_name is replaced by a NodeArg named new_output_name
        (created in the subgraph if needed, with the type of the replaced NodeArg).
    An ORT_ENFORCE check fails if a replaced input slot has an input edge. */
static void UpdateImplicitInputNameInSubgraph(Node& node,
                                              const std::string& removed_output_name,
                                              const std::string& new_output_name) {
  for (auto& attr_subgraph_pair : node.GetAttributeNameToMutableSubgraphMap()) {
    Graph& subgraph = *attr_subgraph_pair.second;

    for (auto& subgraph_node : subgraph.Nodes()) {
      // Recurse if this node also consumes removed_output_name as an implicit input, i.e. there are
      // multiple levels of nested subgraphs and at least one lower level uses removed_output_name.
      // Bind by reference to avoid copying the vector, and guard against null entries in the same way
      // CanUpdateImplicitInputNameInSubgraph does.
      const auto& subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs();
      const bool subgraph_node_also_consumes_nodearg_as_implicit_input =
          std::any_of(subgraph_node_implicit_inputs.cbegin(), subgraph_node_implicit_inputs.cend(),
                      [&removed_output_name](const NodeArg* input) {
                        return input != nullptr && input->Name() == removed_output_name;
                      });

      if (subgraph_node_also_consumes_nodearg_as_implicit_input) {
        UpdateImplicitInputNameInSubgraph(subgraph_node, removed_output_name, new_output_name);
      }

      // Need mutable input defs to be able to update the input names.
      auto& input_args = subgraph_node.MutableInputDefs();
      int input_slot_index = -1;

      // input_arg is a copy of the pointer, so overwriting the slot below does not affect the iteration.
      for (const auto* input_arg : input_args) {
        ++input_slot_index;

        // If the input matches, replace the NodeArg with one using the new name.
        if (input_arg != nullptr && input_arg->Exists() && input_arg->Name() == removed_output_name) {
          // Sanity check there was no edge for this input. Inputs coming from the outer scope do not have edges.
          ORT_ENFORCE(std::count_if(subgraph_node.InputEdgesBegin(), subgraph_node.InputEdgesEnd(),
                                    [input_slot_index](const Node::EdgeEnd& entry) {
                                      return entry.GetDstArgIndex() == input_slot_index;
                                    }) == 0);

          // Create (or fetch) a NodeArg with the new name, keeping the type of the one being replaced.
          input_args[input_slot_index] = &subgraph.GetOrCreateNodeArg(new_output_name,
                                                                      input_arg->TypeAsProto());
        }
      }
    }
  }
}
/** For each of the given output edges that provides an implicit input to its destination node, checks that
    renaming the edge's NodeArg to new_arg_name would not clash with an existing NodeArg in the destination
    node's subgraphs. This matters when removing a node whose input NodeArg takes over for its output.
    @returns false (after logging a warning) on the first clash, true if there is none. */
static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
                                                  const std::vector<GraphEdge>& output_edges,
                                                  const std::string& new_arg_name, const logging::Logger& logger) {
  for (const auto& output_edge : output_edges) {
    if (!OutputEdgeProvidesImplicitInput(graph, output_edge)) {
      continue;
    }

    const Node& consumer = *graph.GetNode(output_edge.dst_node);
    if (CanUpdateImplicitInputNameInSubgraph(consumer, output_edge.arg_name, new_arg_name)) {
      continue;
    }

    LOGS(logger, WARNING) << " Implicit input name " << output_edge.arg_name
                          << " cannot be safely updated to " << new_arg_name << " in one of the subgraphs.";
    return false;
  }

  return true;
}
/** Removes `node`, which is fed by a single incoming node, and reconnects that incoming node
    directly to the consumers of `node`'s single used output. Always returns true. */
static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node) {
  // Snapshot the output edges before anything is modified.
  std::vector<GraphEdge> output_edges = GraphEdge::GetNodeOutputEdges(node);

  if (output_edges.empty()) {
    // nothing downstream to rewire
    graph.RemoveNode(node.Index());
    return true;
  }

  // Look up the producer through the graph to get a non-const Node.
  const Node::EdgeEnd& input_edge = *node.InputEdgesBegin();
  Node& incoming_node = *graph.GetNode(input_edge.GetNode().Index());

  auto src_idx = output_edges.front().src_arg_index;
  ORT_ENFORCE(std::all_of(output_edges.cbegin(), output_edges.cend(),
                          [&src_idx](const GraphEdge& edge) {
                            return edge.src_arg_index == src_idx;
                          }),
              "Node must only have one used output");

  // Redirect the consumers of 'node' to the matching output of its incoming node.
  ReplaceDownstreamNodeInput(graph, node, src_idx, incoming_node, input_edge.GetSrcArgIndex());

  graph.RemoveNode(node.Index());
  return true;
}
/** Re-creates every input edge of src_node on target_node, matching input slots by NodeArg name,
    and then removes the original edges from src_node. */
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
  const auto incoming_edges = GraphEdge::GetNodeInputEdges(src_node);
  const auto target_idx = target_node.Index();

  for (const GraphEdge& edge : incoming_edges) {
    // the slot on target_node is whichever input def carries the same name as the original edge
    const auto target_slot = GetNodeInputIndexFromInputName(target_node, edge.arg_name);
    graph.AddEdge(edge.src_node, target_idx, edge.src_arg_index, target_slot);
  }

  GraphEdge::RemoveGraphEdges(graph, incoming_edges);
}
/** Transfers the output defs and output edges of src_node to target_node.
    Afterwards src_node has no output edges left, so it can be removed with Graph::RemoveNode.
*/
static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node) {
  // target_node takes over the NodeArg pointers of all of src_node's output defs
  target_node.MutableOutputDefs() = src_node.MutableOutputDefs();

  const auto outgoing_edges = GraphEdge::GetNodeOutputEdges(src_node);
  const auto target_idx = target_node.Index();

  for (const GraphEdge& edge : outgoing_edges) {
    graph.AddEdge(target_idx, edge.dst_node, edge.src_arg_index, edge.dst_arg_index);
  }

  GraphEdge::RemoveGraphEdges(graph, outgoing_edges);
}
#endif // !defined(ORT_MINIMAL_BUILD)
//----------------------------
//--- end of local helpers ---
//----------------------------
bool MatchesOpSinceVersion(const Node& node, std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> versions) {
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
}
bool MatchesOpSinceVersion(const Node& node, gsl::span<const ONNX_NAMESPACE::OperatorSetVersion> versions) {
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
}
/** Returns true if the node's domain is equal to `domain`. */
bool MatchesOpSetDomain(const Node& node, std::string_view domain) {
  return node.Domain() == domain;
}
bool IsSupportedOptypeVersionAndDomain(const Node& node,
std::string_view op_type,
std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> versions,
std::string_view domain) {
std::vector<ONNX_NAMESPACE::OperatorSetVersion> versions_vec(versions);
return IsSupportedOptypeVersionAndDomain(node, op_type, versions_vec, domain);
}
/** Returns true if the node has the given op type, a SinceVersion() contained in `versions` and the given
    domain. In non-minimal builds the op must additionally not be flagged as deprecated. */
bool IsSupportedOptypeVersionAndDomain(const Node& node, std::string_view op_type,
                                       gsl::span<const ONNX_NAMESPACE::OperatorSetVersion> versions,
                                       std::string_view domain) {
  if (!(node.OpType() == op_type)) {
    return false;
  }

  // op schemas are not available in the minimal build, so the deprecated flag can only be checked otherwise
#if !defined(ORT_MINIMAL_BUILD)
  if (node.Op()->Deprecated()) {
    return false;
  }
#endif

  return MatchesOpSinceVersion(node, versions) && MatchesOpSetDomain(node, domain);
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Looks up the attribute called attr_name on the node.
    @returns pointer to the attribute, or nullptr if the node has no attribute with that name. */
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
  const auto& attributes = node.GetAttributes();
  const auto found = attributes.find(attr_name);
  if (found == attributes.end()) {
    return nullptr;
  }
  return &found->second;
}
/** Adds new_initializer to the graph and returns the NodeArg for it.
    The NodeArg's type is a tensor type built from the initializer's data type and dims.
    Fails an ORT_ENFORCE check if an initializer with the same name is already present. */
NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
  // Graph::AddInitializedTensor silently ignores duplicates, so reject them explicitly up front.
  const ONNX_NAMESPACE::TensorProto* existing = nullptr;
  ORT_ENFORCE(!graph.GetInitializedTensor(new_initializer.name(), existing),
              "Initializer with same name exists. Name:", new_initializer.name());

  graph.AddInitializedTensor(new_initializer);

  // Describe the initializer as a tensor type: element type plus one fixed dimension per entry in dims().
  ONNX_NAMESPACE::TypeProto initializer_type;
  auto& tensor_type = *initializer_type.mutable_tensor_type();
  tensor_type.set_elem_type(new_initializer.data_type());

  auto& tensor_shape = *tensor_type.mutable_shape();
  for (auto dim_value : new_initializer.dims()) {
    tensor_shape.add_dim()->set_dim_value(dim_value);
  }

  return graph.GetOrCreateNodeArg(new_initializer.name(), &initializer_type);
}
/** Returns the index of the node's output def named output_name.
    Fails an ORT_ENFORCE check (in GetIndexFromName) if there is no such output. */
int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name) {
  constexpr bool search_inputs = false;
  return GetIndexFromName(node, output_name, search_inputs);
}
std::vector<const Node*> FindParentsByType(const Node& node, const std::string& parent_type) {
// find parents and sort them by destination argument index
// as there is at most one input edge for each input argument,
// there is no need of extra work like FindChildrenByType
std::vector<const Node*> parents(node.InputDefs().size(), nullptr);
for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(parent_type) == 0) {
parents[it->GetDstArgIndex()] = &(it->GetNode());
}
}
// remove unmatched nodes
parents.erase(std::remove(parents.begin(), parents.end(), nullptr), parents.end());
return parents;
}
std::vector<const Node*> FindChildrenByType(const Node& node, const std::string& child_type) {
// find children and sort them by source argument index:
// Create a 2D vector to hold the result.
// 1st dimension index is output index,
// and the 2nd dimension stores the edges from the output.
std::vector<std::vector<const Node*>> children(node.OutputDefs().size(), std::vector<const Node*>());
for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(child_type) == 0) {
children[it->GetSrcArgIndex()].push_back(&(it->GetNode()));
}
}
// aggregate children
std::vector<const Node*> agg_res;
for (size_t output_idx = 0; output_idx < children.size(); output_idx++) {
agg_res.insert(agg_res.end(), children[output_idx].begin(), children[output_idx].end());
}
return agg_res;
}
const std::string& GetNodeInputName(const Node& node, int index) {
const auto& inputs = node.InputDefs();
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < inputs.size(),
"Attempting to get an input that does not exist.");
return inputs[index]->Name();
}
const std::string& GetNodeOutputName(const Node& node, int index) {
const auto& outputs = node.OutputDefs();
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < outputs.size(),
"Attempting to get an output that does not exist.");
return outputs[index]->Name();
}
/** Removes all output edges of the node from the graph.
    @returns the number of edges that were collected for removal. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node) {
  const std::vector<GraphEdge> edges_to_drop = GraphEdge::GetNodeOutputEdges(node);
  const size_t num_edges = edges_to_drop.size();
  GraphEdge::RemoveGraphEdges(graph, edges_to_drop);
  return num_edges;
}
/** Removes the output edges of the node that originate from output `output_idx`.
    @returns the number of edges that were collected for removal. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node, int output_idx) {
  const std::vector<GraphEdge> edges_to_drop = GraphEdge::GetNodeOutputEdges(node, output_idx);
  const size_t num_edges = edges_to_drop.size();
  GraphEdge::RemoveGraphEdges(graph, edges_to_drop);
  return num_edges;
}
/** Thin wrapper that forwards to Graph::GetConstantInitializer with the same arguments
    and returns its result unchanged. */
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, const std::string& initializer_name,
                                                          bool check_outer_scope) {
  const ONNX_NAMESPACE::TensorProto* constant_initializer =
      graph.GetConstantInitializer(initializer_name, check_outer_scope);
  return constant_initializer;
}
/** Stores the edge description: source/destination node indexes, the source output index,
    the destination input index, and the name of the NodeArg carried by the edge. */
GraphEdge::GraphEdge(NodeIndex src_node,
                     NodeIndex dst_node,
                     int src_arg_index,
                     int dst_arg_index,
                     const std::string& arg_name)
    : src_node(src_node),
      dst_node(dst_node),
      src_arg_index(src_arg_index),
      dst_arg_index(dst_arg_index),
      arg_name(arg_name) {
}
// Builds a GraphEdge from `node` and one of its edge ends.
// For an input edge, `node` is the destination and the edge is named after the node's input at the edge's
// destination index. For an output edge, `node` is the source and the edge is named after the node's output
// at the edge's source index.
GraphEdge GraphEdge::CreateGraphEdge(const Node& node, const Node::EdgeEnd& edge_end, bool is_input_edge) {
  const Node& other_node = edge_end.GetNode();
  const int src_arg = edge_end.GetSrcArgIndex();
  const int dst_arg = edge_end.GetDstArgIndex();

  if (is_input_edge) {
    return GraphEdge(other_node.Index(), node.Index(), src_arg, dst_arg, GetNodeInputName(node, dst_arg));
  }

  return GraphEdge(node.Index(), other_node.Index(), src_arg, dst_arg, GetNodeOutputName(node, src_arg));
}
/** Returns the first input edge of the node whose destination argument index equals arg_index,
    or nullptr if there is none. */
const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index) {
  const Node::EdgeEnd* matching_edge = nullptr;

  for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
    if (it->GetDstArgIndex() == arg_index) {
      matching_edge = &(*it);
      break;
    }
  }

  return matching_edge;
}
/** Returns a vector of the input GraphEdges of a node. */
std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node) {
std::vector<GraphEdge> input_edges;
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
}
return input_edges;
}
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node, size_t index) {
std::vector<GraphEdge> input_edges;
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
if (static_cast<size_t>(it->GetDstArgIndex()) == index) {
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
}
}
return input_edges;
}
/** Returns a vector of the output GraphEdges of a node. */
std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node) {
std::vector<GraphEdge> output_edges;
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
output_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, false));
}
return output_edges;
}
/** Returns a vector of output GraphEdges of a node for the provided output index. */
std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node, size_t index) {
std::vector<GraphEdge> output_edges;
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
if (static_cast<size_t>(it->GetSrcArgIndex()) == index) {
output_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, false));
}
}
return output_edges;
}
/** Calls Graph::RemoveEdge for each of the given edges, in order. */
void GraphEdge::RemoveGraphEdges(Graph& graph, const std::vector<GraphEdge>& edges) {
  std::for_each(edges.cbegin(), edges.cend(),
                [&graph](const GraphEdge& edge) {
                  graph.RemoveEdge(edge.src_node, edge.dst_node, edge.src_arg_index, edge.dst_arg_index);
                });
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD)
/** Returns the index of the node's input def named input_name.
    Fails an ORT_ENFORCE check (in GetIndexFromName) if there is no such input. */
int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name) {
  constexpr bool search_inputs = true;
  return GetIndexFromName(node, input_name, search_inputs);
}
/** Returns true if compatible_providers is empty, or if it contains the node's execution provider type. */
bool IsSupportedProvider(const Node& node,
                         const InlinedHashSet<std::string_view>& compatible_providers) {
  if (compatible_providers.empty()) {
    // no restriction requested
    return true;
  }

  return compatible_providers.find(node.GetExecutionProviderType()) != compatible_providers.end();
}
/** For a node with one or more outputs, checks that at most one of them is used, where "used" means
    consumed by another node via an output edge and/or exposed as a graph output.
    If a single used output is found, output_name is pointed at its name; otherwise output_name may be
    left untouched or point at the output consumed by downstream nodes.
    @returns false if more than one distinct output is used, true otherwise. */
static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std::string*& output_name) {
  constexpr int unassigned = -1;
  int consumed_output = unassigned;

  // All output edges must originate from the same output index.
  if (node.GetOutputEdgesCount() > 0) {
    for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
      const int edge_output = it->GetSrcArgIndex();
      if (consumed_output == unassigned) {
        consumed_output = edge_output;
        continue;
      }
      if (edge_output != consumed_output) {
        return false;
      }
    }

    output_name = &node.OutputDefs()[consumed_output]->Name();
  }

  // Outputs may also be graph outputs. That is acceptable only if there is at most one such output and,
  // when another node consumes an output too, it is that very same output.
  auto graph_output_indexes = graph.GetNodeOutputsInGraphOutputs(node);

  switch (graph_output_indexes.size()) {
    case 0:
      return true;
    case 1:
      break;
    default:
      return false;
  }

  if (consumed_output != unassigned) {
    return graph_output_indexes.front() == consumed_output;
  }

  // Only a graph output and no consuming nodes: report the graph output's name.
  output_name = &node.OutputDefs()[graph_output_indexes.front()]->Name();
  return true;
}
bool IsOutputUsed(const Node& node, int index) {
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
if (it->GetSrcArgIndex() == index) {
return true;
}
}
return false;
}
/** Checks whether `node` can be removed by RemoveNode, i.e. whether its single used output can be replaced
    by its input.
    Requirements checked here: only one output is used, the node produces no graph output, the node has
    either exactly one input edge or exactly one input def, and renaming the output to that input's name
    does not clash inside any subgraph that implicitly consumes the output. */
bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& logger) {
  const std::string* output_name = nullptr;
  if (!IsOnlyOneOutputUsed(graph, node, output_name)) {
    return false;
  }

  // TODO: Currently we remove the node and use the input name from the node being removed.
  // It may also be possible to instead update an upstream node to use the output name from the node being removed.
  // This would allow removal of a node that is providing a graph output, as that output name would come from updating
  // the upstream node. This should also enable removal if CanUpdateImplicitInputNameInSubgraphs returns false.
  if (graph.NodeProducesGraphOutput(node)) {
    return false;
  }

  // Determine the name that downstream consumers would see after the removal.
  const std::string* new_name = nullptr;
  if (node.GetInputEdgesCount() == 1) {
    // The single input edge gets merged with the edges of the used output. The node may still have other
    // inputs (initializers or graph inputs), which do not have edges.
    new_name = &GetNodeInputName(node, node.InputEdgesBegin()->GetDstArgIndex());
  } else if (node.InputDefs().size() == 1) {
    // A single input coming from an initializer or graph input (no edges) is also supported.
    new_name = &node.InputDefs()[0]->Name();
  }

  if (new_name == nullptr) {
    // no other kind of node removal is supported
    return false;
  }

  // Make sure replacing the current output name with the new one does not break consuming subgraphs.
  const std::vector<GraphEdge> output_edges = GraphEdge::GetNodeOutputEdges(node);
  return CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, *new_name, logger);
}
/** Removes `node` from the graph and connects its input to the consumers of its used output.
    Expects the preconditions verified by CanRemoveNode: the node produces no graph outputs and only a
    single output is used. Throws via ORT_THROW if the node has neither exactly one input edge nor
    exactly one input def. */
bool RemoveNode(Graph& graph, Node& node) {
  // TODO: enable the check back
  // assert(CanRemoveNode(graph, node, nullptr));

  // Case 1: exactly one input edge from another node (initializers are not connected via edges).
  // Remove the node and wire its incoming node to its outgoing node/s.
  const bool fed_by_single_node = node.GetInputEdgesCount() == 1;
  if (fed_by_single_node) {
    return RemoveNodeWithSingleNodeInSingleUsedOutput(graph, node);
  }

  // Case 2: exactly one input def, which then stands in for the node.
  const bool has_single_input_def = node.InputDefs().size() == 1;
  if (has_single_input_def) {
    NodeArg& sole_input = *node.MutableInputDefs()[0];
    return ReplaceNodeWithInitializer(graph, node, sole_input);
  }

  ORT_THROW("Should be unreachable if CanRemoveNodeAndMergeEdges is in sync with the logic here.");
}
/** Checks whether `node` can be replaced by an initializer called initializer_name.
    Requires that exactly one output of the node is used. If the node provides a graph output, the
    initializer must carry the same name as that output. If the name changes, the new name must not clash
    inside any subgraph that implicitly consumes the current output. */
bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const std::string& initializer_name,
                                   const logging::Logger& logger) {
  // Replacing multiple outputs is not supported, so exactly one used output is required.
  const std::string* output_name = nullptr;
  const bool single_output_used = IsOnlyOneOutputUsed(graph, node, output_name);
  if (!single_output_used || output_name == nullptr) {
    return false;
  }

  const bool output_name_is_changing = *output_name != initializer_name;

  // A node providing more than one graph output cannot be removed, and neither can a node whose single
  // graph output would end up with a different name.
  const auto num_graph_outputs = graph.GetNodeOutputsInGraphOutputs(node).size();
  if (num_graph_outputs > 1 || (num_graph_outputs == 1 && output_name_is_changing)) {
    return false;
  }

  if (!output_name_is_changing) {
    return true;
  }

  // The name changes: verify that no subgraph consuming the current name would break.
  const std::vector<GraphEdge> output_edges = GraphEdge::GetNodeOutputEdges(node);
  return CanUpdateImplicitInputNameInSubgraphs(graph, output_edges, initializer_name, logger);
}
/** Removes `node` and makes every former consumer of its outputs read from `replacement` instead.
    Consumers that received the output as an implicit input also get the name updated in their subgraphs.
    Always returns true. */
bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement) {
  // The output edges must be gone before the consumers are re-pointed, so remember them first.
  const std::vector<GraphEdge> former_output_edges = GraphEdge::GetNodeOutputEdges(node);

  // Drop the output edges, then the node itself (which also drops its input edges).
  // `node` must not be touched after this point.
  RemoveNodeOutputEdges(graph, node);
  graph.RemoveNode(node.Index());

  const std::string& replacement_name = replacement.Name();

  // Point each former consumer at 'replacement'.
  for (const GraphEdge& former_edge : former_output_edges) {
    Node& consumer = *graph.GetNode(former_edge.dst_node);

    // An implicit input is referenced by name inside the consumer's subgraphs, so rename it there too.
    if (OutputEdgeProvidesImplicitInput(graph, former_edge)) {
      UpdateImplicitInputNameInSubgraph(consumer, former_edge.arg_name, replacement_name);
    }

    ReplaceNodeInput(consumer, former_edge.dst_arg_index, replacement);
  }

  return true;
}
bool IsGraphInput(const Graph& graph, const NodeArg* input) {
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end();
}
/** Returns true if `name` is an initializer in the graph. If check_outer_scope is set and the graph is a
    subgraph in which `name` is an outer scope value, the parent graphs are searched recursively. */
bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) {
  const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
  if (graph.GetInitializedTensor(name, initializer)) {
    return true;
  }

  const bool defined_in_outer_scope = check_outer_scope && graph.IsSubgraph() && graph.IsOuterScopeValue(name);
  if (!defined_in_outer_scope) {
    return false;
  }

  return IsInitializer(*graph.ParentGraph(), name, check_outer_scope);
}
bool IsConstantInitializer(const Graph& graph, const std::string& initializer_name, bool check_outer_scope) {
const ONNX_NAMESPACE::TensorProto* initializer = GetConstantInitializer(graph, initializer_name, check_outer_scope);
return initializer != nullptr;
}
/** Returns true if the NodeArg's name refers to a constant initializer, also searching outer scopes. */
bool NodeArgIsConstant(const Graph& graph, const NodeArg& node_arg) {
  constexpr bool check_outer_scope = true;
  return IsConstantInitializer(graph, node_arg.Name(), check_outer_scope);
}
/** Checks whether every (present) input of the node is a constant initializer that is not excluded.
    @param constant_inputs Receives name -> initializer for all inputs on success; is empty on failure.
    @param excluded_initializers Names that must not be treated as constant.
    @returns true if all inputs are constant, false otherwise. */
bool AllNodeInputsAreConstant(const Graph& graph, const Node& node, InitializedTensorSet& constant_inputs,
                              const InlinedHashSet<std::string>& excluded_initializers) {
  // Start from a known state; every failure path below returns with the set empty again.
  constant_inputs.clear();

  // Only initializers can be constant, and initializers never have an edge to the node.
  // Any input edge therefore means at least one input is produced by another node.
  if (node.GetInputEdgesCount() > 0) {
    return false;
  }

  for (const auto* input_def : node.InputDefs()) {
    const std::string& input_name = input_def->Name();

    // a missing optional input has an empty name and can be skipped
    if (input_name.empty()) {
      continue;
    }

    // Important note: when an initializer appears in the graph's input, this input will not be considered constant,
    // because it can be overridden by the user at runtime. For constant folding to be applied, the initializer should
    // not appear in the graph's inputs (that is the only way to guarantee it will always be constant).
    const ONNX_NAMESPACE::TensorProto* initializer = GetConstantInitializer(graph, input_name, true);
    const bool usable_constant =
        initializer != nullptr && excluded_initializers.find(input_name) == excluded_initializers.cend();

    if (!usable_constant) {
      constant_inputs.clear();
      return false;
    }

    constant_inputs.insert({input_name, initializer});
  }

  return true;
}
/** Returns the first output node of `node` (in OutputNodes iteration order) whose op type equals
    child_type, or nullptr if there is none. */
const Node* FirstChildByType(const Node& node, const std::string& child_type) {
  const Node* first_match = nullptr;

  for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
    const Node& child = *it;
    if (child.OpType() == child_type) {
      first_match = &child;
      break;
    }
  }

  return first_match;
}
/** Returns the first input node of `node` (in InputNodes iteration order) whose op type equals
    parent_type, or nullptr if there is none. */
const Node* FirstParentByType(const Node& node, const std::string& parent_type) {
  const Node* first_match = nullptr;

  for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) {
    const Node& parent = *it;
    if (parent.OpType() == parent_type) {
      first_match = &parent;
      break;
    }
  }

  return first_match;
}
/** Re-routes every consumer of output `output_idx` of `node` so that it reads from output
    `replacement_output_idx` of `replacement` instead. Does nothing if that output has no edges. */
void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) {
  // the edges leaving 'node' from output_idx
  const std::vector<GraphEdge> edges_to_redirect = GraphEdge::GetNodeOutputEdges(node, output_idx);
  if (edges_to_redirect.empty()) {
    return;
  }

  const auto& replacement_name = replacement.MutableOutputDefs()[replacement_output_idx]->Name();
  const auto replacement_idx = replacement.Index();

  // The old edges have to be removed before the new ones are added.
  GraphEdge::RemoveGraphEdges(graph, edges_to_redirect);

  for (const GraphEdge& old_edge : edges_to_redirect) {
    // An implicit input is referenced by name inside the consumer's subgraphs, so rename it there.
    if (OutputEdgeProvidesImplicitInput(graph, old_edge)) {
      Node& consumer = *graph.GetNode(old_edge.dst_node);
      UpdateImplicitInputNameInSubgraph(consumer, old_edge.arg_name, replacement_name);
    }

    // Connect the replacement directly to the consumer.
    // This also updates the destination node's input node args.
    graph.AddEdge(replacement_idx, old_edge.dst_node, replacement_output_idx, old_edge.dst_arg_index);
  }
}
/** Sets the input of `target` at target_input_idx to `new_input`.
    Indexes below the number of explicit inputs address explicit inputs; the following indexes address
    implicit inputs. Throws via ORT_THROW if the index is beyond both ranges. */
void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input) {
  const size_t dst_arg_idx = static_cast<size_t>(target_input_idx);
  const size_t num_explicit_inputs = target.InputDefs().size();

  if (dst_arg_idx < num_explicit_inputs) {
    target.MutableInputDefs()[dst_arg_idx] = &new_input;
    return;
  }

  // Past the explicit inputs: interpret the remainder as an index into the implicit inputs.
  const size_t implicit_idx = dst_arg_idx - num_explicit_inputs;
  if (implicit_idx < target.ImplicitInputDefs().size()) {
    target.MutableImplicitInputDefs()[implicit_idx] = &new_input;
    return;
  }

  // logic error in our code
  ORT_THROW("Invalid input index for node ", target.Name(), ". Index:", target_input_idx,
            " ExplicitInputs:", num_explicit_inputs,
            " ImplicitInputs:", target.ImplicitInputDefs().size());
}
/** Appends `new_input` as a new explicit input of `target` and sets the arg count for that slot to 1.
    target_input_idx must equal the current number of explicit inputs (ORT_ENFORCE otherwise). */
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) {
  auto num_explicit_inputs = target.InputDefs().size();
  ORT_ENFORCE(num_explicit_inputs == static_cast<size_t>(target_input_idx),
              "Can only add a new input at the end of the current ones.");

  auto& explicit_inputs = target.MutableInputDefs();
  explicit_inputs.push_back(&new_input);

  // an entry is expected to exist already for every possible input
  assert(target.MutableInputArgsCount().size() > static_cast<size_t>(target_input_idx));
  auto& arg_counts = target.MutableInputArgsCount();
  arg_counts[target_input_idx] = 1;
}
/** Completes the fusion of second_node into first_node: first_node loses its current output edges,
    takes over the output defs and output edges of second_node, and second_node is removed. */
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node) {
  const auto second_node_idx = second_node.Index();

  // first_node's own output edges go away; it inherits second_node's outputs instead
  RemoveNodeOutputEdges(graph, first_node);
  MoveAllNodeOutputs(graph, second_node, first_node);

  // with no output edges left, second_node can be removed
  graph.RemoveNode(second_node_idx);
}
/** Completes the replacement of a sequence of nodes by replacement nodes: the input edges of the first
    node in `nodes` move to replacement_node_start, the outputs of the last node in `nodes` move to
    replacement_node_end, and then every node in `nodes` has its output edges and itself removed. */
void FinalizeNodeFusion(Graph& graph, gsl::span<const std::reference_wrapper<Node>> nodes, Node& replacement_node_start,
                        Node& replacement_node_end) {
  Node& first_fused = *nodes.begin();
  Node& last_fused = nodes.back();

  MoveAllNodeInputEdges(graph, first_fused, replacement_node_start);
  MoveAllNodeOutputs(graph, last_fused, replacement_node_end);

  for (Node& fused : nodes) {
    RemoveNodeOutputEdges(graph, fused);
    graph.RemoveNode(fused.Index());
  }
}
/** Returns the node at the other end of the input edge targeting input arg_index,
    or nullptr if that input has no edge. */
const Node* GetInputNode(const Node& node, int arg_index) {
  const Node::EdgeEnd* incoming = GetInputEdge(node, arg_index);
  return incoming != nullptr ? &incoming->GetNode() : nullptr;
}
/** Joins the versions into a single string separated by ";" (no trailing separator).
    Returns an empty string for an empty span. */
inline std::string ToString(gsl::span<const ONNX_NAMESPACE::OperatorSetVersion> versions) {
  std::ostringstream output;

  // the separator is empty before the first element and ";" before every later one
  const char* separator = "";
  for (const ONNX_NAMESPACE::OperatorSetVersion version : versions) {
    output << separator << version;
    separator = ";";
  }

  return output.str();
}
/** Walks from `node` along input edges (is_input_edge == true) or output edges, matching one entry of
    edges_to_match per step. An edge matches if its source/destination arg indexes agree and the node at its
    other end has the requested op type, version and domain.
    @param result Receives the matched edge ends in order. On failure it holds the edges matched so far.
    @returns false if a step has no match, or more than one matching edge; true otherwise. */
bool FindPath(const Node& node, bool is_input_edge, gsl::span<const EdgeEndToMatch> edges_to_match,
              std::vector<const Node::EdgeEnd*>& result, const logging::Logger& logger) {
  result.clear();
  result.reserve(edges_to_match.size());

  const Node* current_node = &node;

  for (const auto& edge : edges_to_match) {
#ifndef NDEBUG
    LOGS(logger, VERBOSE) << (is_input_edge ? "I:" : "O:") << edge.src_arg_index << "," << edge.dst_arg_index
                          << "," << edge.op_type << "," << edge.domain << "," << ToString(edge.versions);
#endif

    const Node::EdgeEnd* edge_found = nullptr;
    auto it = is_input_edge ? current_node->InputEdgesBegin() : current_node->OutputEdgesBegin();
    const auto edges_end = is_input_edge ? current_node->InputEdgesEnd() : current_node->OutputEdgesEnd();

    for (; it != edges_end; ++it) {
#ifndef NDEBUG
      LOGS(logger, VERBOSE) << "E:" << it->GetSrcArgIndex() << "," << it->GetDstArgIndex()
                            << "," << it->GetNode().OpType() << "," << it->GetNode().Domain() << ","
                            << it->GetNode().SinceVersion();
#endif

      const bool is_match = edge.dst_arg_index == it->GetDstArgIndex() &&
                            edge.src_arg_index == it->GetSrcArgIndex() &&
                            edge.op_type == it->GetNode().OpType() &&
                            MatchesOpSinceVersion(it->GetNode(), edge.versions) &&
                            MatchesOpSetDomain(it->GetNode(), edge.domain);
      if (!is_match) {
        continue;
      }

      // Several output edges may match the same description. That is treated as a failure by design.
      if (edge_found != nullptr) {
        LOGS(logger, WARNING) << "Failed since multiple edges matched:" << current_node->OpType() << "->" << edge.op_type;
        return false;
      }

      edge_found = &(*it);

      // Each dst_arg_index accepts only one input edge, so for input edges there is at most one match.
      if (is_input_edge) {
        break;
      }
    }

    if (edge_found == nullptr) {
      return false;
    }

    result.push_back(edge_found);
    // continue the walk from the node at the other end of the matched edge
    current_node = &(edge_found->GetNode());
  }

  return true;
}
/** Same matching as the EdgeEnd-based FindPath, but returns mutable references to the nodes at the far end
    of each matched edge, looked up in `graph` by node index. `result` is empty when false is returned. */
bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span<const EdgeEndToMatch> edges_to_match, std::vector<std::reference_wrapper<Node>>& result, const logging::Logger& logger) {
  result.clear();

  std::vector<const Node::EdgeEnd*> matched_edges;
  const bool path_found = FindPath(node, is_input_edge, edges_to_match, matched_edges, logger);
  if (!path_found) {
    return false;
  }

  result.reserve(edges_to_match.size());
  for (const Node::EdgeEnd* matched_edge : matched_edges) {
    Node& mutable_node = *graph.GetNode(matched_edge->GetNode().Index());
    result.push_back(mutable_node);
  }

  return true;
}
/** Removes start_node and then, walking upwards through its producers, every node that is left without
    output edges. Nodes with more than one output edge or that produce a graph output stop the walk.
    @returns true if at least one node was removed, false otherwise. */
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
  const NodeIndex start_node_index = start_node.Index();

  std::queue<NodeIndex> pending;
  InlinedHashSet<NodeIndex> removed_nodes;
  pending.push(start_node_index);

  // Starting at start_node, remove nodes bottom-up until a node with multiple output edges or a graph
  // output is reached.
  while (!pending.empty()) {
    const NodeIndex cur_node_index = pending.front();
    pending.pop();

    const bool already_removed = removed_nodes.find(cur_node_index) != removed_nodes.end();
    if (already_removed) {
      continue;
    }

    // A node is only eligible if it has at most one output edge and none of its outputs is a graph output.
    const Node& cur_node = *graph.GetNode(cur_node_index);
    if (cur_node.GetOutputEdgesCount() > 1 || graph.NodeProducesGraphOutput(cur_node)) {
      continue;
    }

    // Queue the producers of the current node's inputs.
    const int num_inputs = static_cast<int>(cur_node.InputDefs().size());
    for (int input_idx = 0; input_idx < num_inputs; ++input_idx) {
      const std::string& input_name = GetNodeInputName(cur_node, input_idx);

      // initializers and graph inputs have no producer node
      if (IsInitializer(graph, input_name, true) || IsGraphInput(graph, cur_node.InputDefs()[input_idx])) {
        continue;
      }

      if (const Node* parent_node = GetInputNode(cur_node, input_idx)) {
        pending.push(parent_node->Index());
      }
    }

    // The start node is always removed; any other node only once nothing consumes its outputs anymore.
    const bool removable = cur_node_index == start_node_index || cur_node.GetOutputEdgesCount() == 0;
    if (removable) {
      Node& mutable_cur_node = *graph.GetNode(cur_node_index);
      RemoveNodeOutputEdges(graph, mutable_cur_node);
      graph.RemoveNode(cur_node_index);
      removed_nodes.insert(cur_node_index);
    }
  }

  // false means there was nothing to remove
  return !(removed_nodes.size() == 0);
}
/** Creates (or fetches) a NodeArg in the graph whose name is generated by Graph::GenerateNodeArgName
    from base_arg's name and whose type is base_arg's type. */
NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) {
  const auto generated_name = graph.GenerateNodeArgName(base_arg.Name());
  return graph.GetOrCreateNodeArg(generated_name, base_arg.TypeAsProto());
}
#endif // !defined(ORT_MINIMAL_BUILD)
} // namespace graph_utils
} // namespace onnxruntime