Add helper to check if node provides a graph output. (#8186)

* Add helper to check if node provides a graph output. The current approach unnecessarily creates a vector when most of the optimizers only care about a true/false response.

* Undo accidental change

* Fix a couple of issues due to copying from larger set of changes.
This commit is contained in:
Scott McKay 2021-06-30 12:15:42 +10:00 committed by GitHub
parent 17d4545ccb
commit b3479367cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 71 additions and 60 deletions

View file

@ -681,6 +681,19 @@ class Graph {
return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
}
/** Returns true if one or more of the Node outputs are Graph outputs.
@remarks Cheaper than calling GetNodeOutputsInGraphOutputs.
*/
bool GetNodeProvidesGraphOutput(const Node& node) const {
auto end_outputs = graph_outputs_.cend();
for (auto output_def : node.OutputDefs()) {
if (std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
return true;
}
}
return false;
}
/** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
int output_idx = 0;

View file

@ -323,7 +323,7 @@ bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger&
// This would allow removal of a node that is providing a graph output, as that output name would come from updating
// the upstream node. This should also enable removal if CanUpdateImplicitInputNameInSubgraphs returns false.
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}
@ -762,7 +762,7 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
// Each eligible node in the subgraph must have less than one output edge and no output should be
// the graph output
const Node& cur_node = *graph.GetNode(cur_node_index);
if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) {
if (cur_node.GetOutputEdgesCount() > 1 || graph.GetNodeProvidesGraphOutput(cur_node)) {
continue;
}

View file

@ -25,7 +25,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
// To be able to fuse the residual Add,
// the Dropout's output must not be a graph output and
// there must be only one consumer of the Dropout's first output.
if (dropout_consumers_count < 2 && graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) {
if (dropout_consumers_count < 2 && !graph.GetNodeProvidesGraphOutput(dropout_node)) {
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
const Node& last_node = (*last_node_itr);
@ -139,7 +139,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
continue;
}

View file

@ -76,7 +76,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
continue;
}

View file

@ -96,12 +96,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(*node).empty()) {
if (graph.GetNodeProvidesGraphOutput(*node)) {
continue;
}
if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) {
if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
continue;
}
@ -125,7 +125,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
continue;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
next_node.GetOutputEdgesCount() == 1) {
Node& conv_node = *node;
Node& add_node = *graph.GetNode(next_node.Index());

View file

@ -125,7 +125,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
return false;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}

View file

@ -177,7 +177,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const
}
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}

View file

@ -133,7 +133,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
return false;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}

View file

@ -74,7 +74,7 @@ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
return false;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}

View file

@ -31,7 +31,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
node.GetExecutionProviderType() == provider &&
IsSupportedDataType(node) &&
(!require_single_output || node.GetOutputEdgesCount() == 1) &&
graph.GetNodeOutputsInGraphOutputs(node).empty();
!graph.GetNodeProvidesGraphOutput(node);
}
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
@ -146,8 +146,8 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
if (p_cast1_node != nullptr) {
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
// this is fused Cast node, so expect 2 output edges
if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) ||
cast1_node.GetOutputEdgesCount() != 2){
if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) ||
cast1_node.GetOutputEdgesCount() != 2) {
return matchResult;
}
const Node* p_pow_node = graph_utils::FirstChildByType(cast1_node, "Pow");
@ -242,7 +242,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// if this is second formula and if pow node has Cast parent, expect mul5_node has Cast parent as well
NodeArg* cast_input_arg = nullptr;
if (second_formula) {
const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast");
const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast");
if (p_cast1_node != nullptr) {
// we've done the node check in second formula for pow node
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
@ -254,11 +254,11 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& cast3_node = *graph.GetNode(p_cast3_node->Index());
if (!CheckNode(graph, cast3_node, "Cast", {9, 13}, node.GetExecutionProviderType(), true)) {
continue;
}
}
// overwrite and continue as usual
p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul");
nodes_to_fuse.push_back(cast3_node);
// keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion()
p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul");
nodes_to_fuse.push_back(cast3_node);
// keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion()
}
}
@ -275,8 +275,8 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
}
if (input_index == -1) continue;
// check same parent for both mul6 and pow, with or without cast
if (input_index == -1) continue;
// check same parent for both mul6 and pow, with or without cast
if (cast_input_arg != nullptr) {
if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != cast_input_arg->Name())
continue;

View file

@ -61,7 +61,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
continue;
}

View file

@ -42,8 +42,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
nodes_to_remove.push_back(gemm_node);
// check if output node is Transpose
if (output_node_ptr != gemm_node.OutputNodesEnd() &&
gemm_node.InputDefs().size() <= 2 && // C is missing
if (output_node_ptr != gemm_node.OutputNodesEnd() &&
gemm_node.InputDefs().size() <= 2 && // C is missing
output_node_ptr->OpType() == "Transpose") {
Node& output_node = *graph.GetNode(output_node_ptr->Index());
// (AB)' = B'A' : reverse the inputs
@ -83,7 +83,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
node_it->GetOutputEdgesCount() == 1 &&
graph.GetNodeOutputsInGraphOutputs(*node_it).empty() &&
!graph.GetNodeProvidesGraphOutput(*node_it) &&
// Make sure the two nodes do not span execution providers.
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
return true;
@ -94,7 +94,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
// by the rule (AB)' = B'A' provided that C is missing
// Supported for Opset >=11 as earlier opsets have C as a required input
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {11, 13}) ||
!graph.GetNodeOutputsInGraphOutputs(node).empty() ||
graph.GetNodeProvidesGraphOutput(node) ||
// verify that C is missing
node.InputDefs().size() > 2) {
return false;

View file

@ -40,7 +40,7 @@ namespace onnxruntime {
X (def0/arg0) ---> Identity ---> Y
*/
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (!graph.GetNodeProvidesGraphOutput(node)) {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
@ -65,7 +65,7 @@ bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, c
return true;
}
bool node_output_is_graph_output = !graph.GetNodeOutputsInGraphOutputs(node).empty();
bool node_output_is_graph_output = graph.GetNodeProvidesGraphOutput(node);
// relax the condition if Identity is connecting to graph output
if (node.GetOutputEdgesCount() != 0 ||

View file

@ -94,7 +94,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
// and is assigned to the CPU EP (we have fp32 implementations of all kernels so forcing to fp32 is safe)
if (node.GetInputEdgesCount() > 0 &&
!node.ContainsSubgraph() &&
graph.GetNodeOutputsInGraphOutputs(node).empty() &&
!graph.GetNodeProvidesGraphOutput(node) &&
node.GetExecutionProviderType() == kCpuExecutionProvider) {
do {
// find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast

View file

@ -35,7 +35,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10}) ||
isinf_node.GetOutputEdgesCount() != 1 ||
!graph.GetNodeOutputsInGraphOutputs(isinf_node).empty()) {
graph.GetNodeProvidesGraphOutput(isinf_node)) {
continue;
}
@ -67,7 +67,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
Node& cast2_node = *graph.GetNode(cast2_node_itr->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13}) ||
cast2_node.GetOutputEdgesCount() != 1 ||
!graph.GetNodeOutputsInGraphOutputs(cast2_node).empty()) {
graph.GetNodeProvidesGraphOutput(cast2_node)) {
continue;
}
nodes_to_remove.push_back(cast2_node);
@ -80,7 +80,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
Node& reduce_sum_node = *graph.GetNode(reduce_sum_node_itr->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_sum_node, "ReduceSum", {1, 11, 13}) ||
reduce_sum_node.GetOutputEdgesCount() != 1 ||
!graph.GetNodeOutputsInGraphOutputs(reduce_sum_node).empty()) {
graph.GetNodeProvidesGraphOutput(reduce_sum_node)) {
continue;
}
nodes_to_remove.push_back(reduce_sum_node);

View file

@ -79,7 +79,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
!graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() ||
graph.GetNodeProvidesGraphOutput(reduce_mean_node) ||
!IsSupportedDataType(reduce_mean_node)) {
continue;
}
@ -377,7 +377,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) ||
!graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, pow_node, 1) ||
!graph.GetNodeOutputsInGraphOutputs(pow_node).empty() ||
graph.GetNodeProvidesGraphOutput(pow_node) ||
!IsSupportedDataType(pow_node)) {
continue;
}

View file

@ -30,7 +30,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
continue;
}
@ -89,7 +89,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
// GEMM only supports unidirectional broadcast on the bias input C
if (!gemm_input_defs.back()->Shape()) {
continue;
continue;
}
const auto& bias_shape = *gemm_input_defs.back()->Shape();
const auto& M = matmul_output.Shape()->dim()[0];

View file

@ -39,7 +39,7 @@ static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
}
// if the node has Graph output, skip it too
if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) {
if (graph.GetNodeProvidesGraphOutput(*trans_node)) {
return nullptr;
}
@ -117,9 +117,8 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
* V
*/
static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
std::unordered_map<NodeArg*, size_t>& consumer_count,
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
std::unordered_map<NodeArg*, size_t>& consumer_count,
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
ORT_ENFORCE(cast != nullptr);
auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
if (transpose == nullptr) {
@ -138,18 +137,18 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
const std::vector<NodeArg*> new_cast_input_defs {transpose_input};
const std::vector<NodeArg*> new_cast_output_defs {&new_cast_output};
const std::vector<NodeArg*> new_cast_input_defs{transpose_input};
const std::vector<NodeArg*> new_cast_output_defs{&new_cast_output};
const std::vector<NodeArg*> new_transpose_input_defs = {&new_cast_output};
const std::vector<NodeArg*> new_transpose_output_defs = {cast_output};
(void) graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
cast->OpType(),
"Created a new Cast node to interchange Cast and Transpose nodes",
new_cast_input_defs,
new_cast_output_defs,
&cast->GetAttributes(),
cast->Domain());
(void)graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
cast->OpType(),
"Created a new Cast node to interchange Cast and Transpose nodes",
new_cast_input_defs,
new_cast_output_defs,
&cast->GetAttributes(),
cast->Domain());
Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"),
transpose->OpType(),
@ -169,8 +168,7 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
}
// Check whether the element_type is an allowed FusedMatMul data type or not.
static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type)
{
static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type) {
return element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
element_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE ||

View file

@ -164,7 +164,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
}
// Bias the edge count to handle the case of a node that produces a graph
// output.
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph_.GetNodeProvidesGraphOutput(node)) {
output_edges_count++;
}
return output_edges_count;
@ -1145,7 +1145,7 @@ void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) {
// Verify that the node does not produce a graph output and produces output
// for a single node.
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) {
if (graph_.GetNodeProvidesGraphOutput(node) || node.GetOutputEdgesCount() != 1) {
return;
}

View file

@ -26,7 +26,7 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
// check if dq_node has only one output edge and,
// dq_node and q_node output are not graph outputs
if (!optimizer_utils::CheckOutputEdges(graph, dq_node, 1) ||
!graph.GetNodeOutputsInGraphOutputs(q_node).empty()) {
graph.GetNodeProvidesGraphOutput(q_node)) {
return false;
}

View file

@ -173,8 +173,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) &&
CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) &&
graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() &&
graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) {
!graph.GetNodeProvidesGraphOutput(*p_add1) &&
!graph.GetNodeProvidesGraphOutput(*p_add2)) {
matched_format = Format::Format1;
}
}
@ -191,8 +191,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) &&
CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) &&
graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() &&
graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) {
!graph.GetNodeProvidesGraphOutput(*p_add1) &&
!graph.GetNodeProvidesGraphOutput(*p_add2)) {
matched_format = Format::Format2;
}
}

View file

@ -267,7 +267,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
}
bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_output_edges) {
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
if (graph.GetNodeProvidesGraphOutput(node)) {
return false;
}