diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 432a61fb91..ffc6fe1744 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -681,6 +681,19 @@ class Graph { return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end(); } + /** Returns true if one or more of the Node outputs are Graph outputs. + @remarks Cheaper than calling GetNodeOutputsInGraphOutputs. + */ + bool GetNodeProvidesGraphOutput(const Node& node) const { + auto end_outputs = graph_outputs_.cend(); + for (auto output_def : node.OutputDefs()) { + if (std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) { + return true; + } + } + return false; + } + /** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */ std::vector GetNodeOutputsInGraphOutputs(const Node& node) const { int output_idx = 0; diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 741eb39c13..f0f9c9a212 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -323,7 +323,7 @@ bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger& // This would allow removal of a node that is providing a graph output, as that output name would come from updating // the upstream node. This should also enable removal if CanUpdateImplicitInputNameInSubgraphs returns false. - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; } @@ -762,7 +762,7 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { // Each eligible node in the subgraph must have less than one output edge and no output should be // the graph output const Node& cur_node = *graph.GetNode(cur_node_index); - if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) { + if (cur_node.GetOutputEdgesCount() > 1 || graph.GetNodeProvidesGraphOutput(cur_node)) { continue; } diff --git a/onnxruntime/core/optimizer/bias_dropout_fusion.cc b/onnxruntime/core/optimizer/bias_dropout_fusion.cc index 0c59123656..27603008eb 100644 --- a/onnxruntime/core/optimizer/bias_dropout_fusion.cc +++ b/onnxruntime/core/optimizer/bias_dropout_fusion.cc @@ -25,7 +25,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node, // To be able to fuse the residual Add, // the Dropout's output must not be a graph output and // there must be only one consumer of the Dropout's first output. - if (dropout_consumers_count < 2 && graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) { + if (dropout_consumers_count < 2 && !graph.GetNodeProvidesGraphOutput(dropout_node)) { for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { const Node& last_node = (*last_node_itr); @@ -139,7 +139,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve continue; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { continue; } diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index 88ed36c123..5618f8b4c3 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -76,7 +76,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { continue; } diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 52351ebabc..38838f61ab 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -96,12 +96,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } - if (!graph.GetNodeOutputsInGraphOutputs(*node).empty()) { + if (graph.GetNodeProvidesGraphOutput(*node)) { continue; } if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) { - if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { continue; } @@ -125,7 +125,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) { continue; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) && next_node.GetOutputEdgesCount() == 1) { Node& conv_node = *node; Node& add_node = *graph.GetNode(next_node.Index()); diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index 57c52643f9..67a718c4e0 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -125,7 +125,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const return false; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; } diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index a3cceb9524..21056b27d5 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -177,7 +177,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const } } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; } diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index a70618460c..dd91b41ef7 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -133,7 +133,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const return false; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; } diff --git a/onnxruntime/core/optimizer/div_mul_fusion.cc b/onnxruntime/core/optimizer/div_mul_fusion.cc index 39701c3248..565df3f672 100644 --- a/onnxruntime/core/optimizer/div_mul_fusion.cc +++ b/onnxruntime/core/optimizer/div_mul_fusion.cc @@ -74,7 +74,7 @@ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const return false; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; } diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc index d42aedac8d..2014433bea 100644 --- a/onnxruntime/core/optimizer/fast_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -31,7 +31,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name node.GetExecutionProviderType() == provider && IsSupportedDataType(node) && (!require_single_output || node.GetOutputEdgesCount() == 1) && - graph.GetNodeOutputsInGraphOutputs(node).empty(); + !graph.GetNodeProvidesGraphOutput(node); } MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, @@ -146,8 +146,8 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, if (p_cast1_node != nullptr) { Node& cast1_node = *graph.GetNode(p_cast1_node->Index()); // this is fused Cast node, so expect 2 output edges - if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) || - cast1_node.GetOutputEdgesCount() != 2){ + if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) || + cast1_node.GetOutputEdgesCount() != 2) { return matchResult; } const Node* p_pow_node = graph_utils::FirstChildByType(cast1_node, "Pow"); @@ -242,7 +242,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // if this is second formula and if pow node has Cast parent, expect mul5_node has Cast parent as well NodeArg* cast_input_arg = nullptr; if (second_formula) { - const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast"); + const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast"); if (p_cast1_node != nullptr) { // we've done the node check in second formula for pow node Node& cast1_node = *graph.GetNode(p_cast1_node->Index()); @@ -254,11 +254,11 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& cast3_node = *graph.GetNode(p_cast3_node->Index()); if (!CheckNode(graph, cast3_node, "Cast", {9, 13}, node.GetExecutionProviderType(), true)) { continue; - } + } // overwrite and continue as usual - p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul"); - nodes_to_fuse.push_back(cast3_node); - // keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion() + p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul"); + nodes_to_fuse.push_back(cast3_node); + // keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion() } } @@ -275,8 +275,8 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } } - if (input_index == -1) continue; - // check same parent for both mul6 and pow, with or without cast + if (input_index == -1) continue; + // check same parent for both mul6 and pow, with or without cast if (cast_input_arg != nullptr) { if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != cast_input_arg->Name()) continue; diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index 214d3163fb..0cbb82a9b4 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -61,7 +61,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { continue; } diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc index 922b4ec22e..d52944f72f 100644 --- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc @@ -42,8 +42,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m nodes_to_remove.push_back(gemm_node); // check if output node is Transpose - if (output_node_ptr != gemm_node.OutputNodesEnd() && - gemm_node.InputDefs().size() <= 2 && // C is missing + if (output_node_ptr != gemm_node.OutputNodesEnd() && + gemm_node.InputDefs().size() <= 2 && // C is missing output_node_ptr->OpType() == "Transpose") { Node& output_node = *graph.GetNode(output_node_ptr->Index()); // (AB)' = B'A' : reverse the inputs @@ -83,7 +83,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node, for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) { if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) && node_it->GetOutputEdgesCount() == 1 && - graph.GetNodeOutputsInGraphOutputs(*node_it).empty() && + !graph.GetNodeProvidesGraphOutput(*node_it) && // Make sure the two nodes do not span execution providers. node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) { return true; @@ -94,7 +94,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node, // by the rule (AB)' = B'A' provided that C is missing // Supported for Opset >=11 as earlier opsets have C as a required input if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {11, 13}) || - !graph.GetNodeOutputsInGraphOutputs(node).empty() || + graph.GetNodeProvidesGraphOutput(node) || // verify that C is missing node.InputDefs().size() > 2) { return false; diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 845a5a2f1b..f762cb0d5f 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -40,7 +40,7 @@ namespace onnxruntime { X (def0/arg0) ---> Identity ---> Y */ Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { - if (graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (!graph.GetNodeProvidesGraphOutput(node)) { if (graph_utils::RemoveNode(graph, node)) { rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } @@ -65,7 +65,7 @@ bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, c return true; } - bool node_output_is_graph_output = !graph.GetNodeOutputsInGraphOutputs(node).empty(); + bool node_output_is_graph_output = graph.GetNodeProvidesGraphOutput(node); // relax the condition if Identity is connecting to graph output if (node.GetOutputEdgesCount() != 0 || diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 9e5cd6b567..6663c49a06 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -94,7 +94,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: // and is assigned to the CPU EP (we have fp32 implementations of all kernels so forcing to fp32 is safe) if (node.GetInputEdgesCount() > 0 && !node.ContainsSubgraph() && - graph.GetNodeOutputsInGraphOutputs(node).empty() && + !graph.GetNodeProvidesGraphOutput(node) && node.GetExecutionProviderType() == kCpuExecutionProvider) { do { // find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast diff --git a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc index 3de8a32ca4..51c704c436 100644 --- a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc +++ b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc @@ -35,7 +35,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10}) || isinf_node.GetOutputEdgesCount() != 1 || - !graph.GetNodeOutputsInGraphOutputs(isinf_node).empty()) { + graph.GetNodeProvidesGraphOutput(isinf_node)) { continue; } @@ -67,7 +67,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l Node& cast2_node = *graph.GetNode(cast2_node_itr->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13}) || cast2_node.GetOutputEdgesCount() != 1 || - !graph.GetNodeOutputsInGraphOutputs(cast2_node).empty()) { + graph.GetNodeProvidesGraphOutput(cast2_node)) { continue; } nodes_to_remove.push_back(cast2_node); @@ -80,7 +80,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l Node& reduce_sum_node = *graph.GetNode(reduce_sum_node_itr->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_sum_node, "ReduceSum", {1, 11, 13}) || reduce_sum_node.GetOutputEdgesCount() != 1 || - !graph.GetNodeOutputsInGraphOutputs(reduce_sum_node).empty()) { + graph.GetNodeProvidesGraphOutput(reduce_sum_node)) { continue; } nodes_to_remove.push_back(reduce_sum_node); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index c6fee2556c..68ebc089e1 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -79,7 +79,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) || !graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) || (reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) || - !graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() || + graph.GetNodeProvidesGraphOutput(reduce_mean_node) || !IsSupportedDataType(reduce_mean_node)) { continue; } @@ -377,7 +377,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) || !graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, pow_node, 1) || - !graph.GetNodeOutputsInGraphOutputs(pow_node).empty() || + graph.GetNodeProvidesGraphOutput(pow_node) || !IsSupportedDataType(pow_node)) { continue; } diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index d65a8343c1..2ba339dd46 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -30,7 +30,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { continue; } @@ -89,7 +89,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as // GEMM only supports unidirectional broadcast on the bias input C if (!gemm_input_defs.back()->Shape()) { - continue; + continue; } const auto& bias_shape = *gemm_input_defs.back()->Shape(); const auto& M = matmul_output.Shape()->dim()[0]; diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index 703babffe9..f8211ed19b 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -39,7 +39,7 @@ static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) { } // if the node has Graph output, skip it too - if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) { + if (graph.GetNodeProvidesGraphOutput(*trans_node)) { return nullptr; } @@ -117,9 +117,8 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ * V */ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, - std::unordered_map& consumer_count, - std::deque& removed_nodes) { - + std::unordered_map& consumer_count, + std::deque& removed_nodes) { ORT_ENFORCE(cast != nullptr); auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]); if (transpose == nullptr) { @@ -138,18 +137,18 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type); auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto); - const std::vector new_cast_input_defs {transpose_input}; - const std::vector new_cast_output_defs {&new_cast_output}; + const std::vector new_cast_input_defs{transpose_input}; + const std::vector new_cast_output_defs{&new_cast_output}; const std::vector new_transpose_input_defs = {&new_cast_output}; const std::vector new_transpose_output_defs = {cast_output}; - (void) graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"), - cast->OpType(), - "Created a new Cast node to interchange Cast and Transpose nodes", - new_cast_input_defs, - new_cast_output_defs, - &cast->GetAttributes(), - cast->Domain()); + (void)graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"), + cast->OpType(), + "Created a new Cast node to interchange Cast and Transpose nodes", + new_cast_input_defs, + new_cast_output_defs, + &cast->GetAttributes(), + cast->Domain()); Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"), transpose->OpType(), @@ -169,8 +168,7 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, } // Check whether the element_type is an allowed FusedMatMul data type or not. -static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type) -{ +static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type) { return element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || element_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE || diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index ae4a87eda4..16beae1747 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -164,7 +164,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) { } // Bias the edge count to handle the case of a node that produces a graph // output. - if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph_.GetNodeProvidesGraphOutput(node)) { output_edges_count++; } return output_edges_count; @@ -1145,7 +1145,7 @@ void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) { // Verify that the node does not produce a graph output and produces output // for a single node. - if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) { + if (graph_.GetNodeProvidesGraphOutput(node) || node.GetOutputEdgesCount() != 1) { return; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 2e3f3c71e5..ada8e4a5ae 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -26,7 +26,7 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) { // check if dq_node has only one output edge and, // dq_node and q_node output are not graph outputs if (!optimizer_utils::CheckOutputEdges(graph, dq_node, 1) || - !graph.GetNodeOutputsInGraphOutputs(q_node).empty()) { + graph.GetNodeProvidesGraphOutput(q_node)) { return false; } diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 6c291a058b..fde661172a 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -173,8 +173,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) && CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) && - graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() && - graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) { + !graph.GetNodeProvidesGraphOutput(*p_add1) && + !graph.GetNodeProvidesGraphOutput(*p_add2)) { matched_format = Format::Format1; } } @@ -191,8 +191,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) && CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) && - graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() && - graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) { + !graph.GetNodeProvidesGraphOutput(*p_add1) && + !graph.GetNodeProvidesGraphOutput(*p_add2)) { matched_format = Format::Format2; } } diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 424d7db60d..fd9bfe3fcd 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -267,7 +267,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) { } bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_output_edges) { - if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph.GetNodeProvidesGraphOutput(node)) { return false; }