mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add helper to check if node provides a graph output. (#8186)
* Add helper to check if node provides a graph output. The current approach unnecessarily creates a vector when most of the optimizers only care about a true/false response. * Undo accidental change * Fix a couple of issues due to copying from larger set of changes.
This commit is contained in:
parent
17d4545ccb
commit
b3479367cf
22 changed files with 71 additions and 60 deletions
|
|
@ -681,6 +681,19 @@ class Graph {
|
|||
return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
|
||||
}
|
||||
|
||||
/** Returns true if one or more of the Node outputs are Graph outputs.
|
||||
@remarks Cheaper than calling GetNodeOutputsInGraphOutputs.
|
||||
*/
|
||||
bool GetNodeProvidesGraphOutput(const Node& node) const {
|
||||
auto end_outputs = graph_outputs_.cend();
|
||||
for (auto output_def : node.OutputDefs()) {
|
||||
if (std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
|
||||
std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
|
||||
int output_idx = 0;
|
||||
|
|
|
|||
|
|
@ -323,7 +323,7 @@ bool CanRemoveNode(const Graph& graph, const Node& node, const logging::Logger&
|
|||
// This would allow removal of a node that is providing a graph output, as that output name would come from updating
|
||||
// the upstream node. This should also enable removal if CanUpdateImplicitInputNameInSubgraphs returns false.
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -762,7 +762,7 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
|||
// Each eligible node in the subgraph must have less than one output edge and no output should be
|
||||
// the graph output
|
||||
const Node& cur_node = *graph.GetNode(cur_node_index);
|
||||
if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) {
|
||||
if (cur_node.GetOutputEdgesCount() > 1 || graph.GetNodeProvidesGraphOutput(cur_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node,
|
|||
// To be able to fuse the residual Add,
|
||||
// the Dropout's output must not be a graph output and
|
||||
// there must be only one consumer of the Dropout's first output.
|
||||
if (dropout_consumers_count < 2 && graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) {
|
||||
if (dropout_consumers_count < 2 && !graph.GetNodeProvidesGraphOutput(dropout_node)) {
|
||||
for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) {
|
||||
const Node& last_node = (*last_node_itr);
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -96,12 +96,12 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(*node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(*node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) {
|
||||
if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
|
||||
if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -125,7 +125,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
|
||||
continue;
|
||||
}
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13, 14}) &&
|
||||
next_node.GetOutputEdgesCount() == 1) {
|
||||
Node& conv_node = *node;
|
||||
Node& add_node = *graph.GetNode(next_node.Index());
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
}
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ static bool CheckNode(Graph& graph, const Node& node, const std::string& op_name
|
|||
node.GetExecutionProviderType() == provider &&
|
||||
IsSupportedDataType(node) &&
|
||||
(!require_single_output || node.GetOutputEdgesCount() == 1) &&
|
||||
graph.GetNodeOutputsInGraphOutputs(node).empty();
|
||||
!graph.GetNodeProvidesGraphOutput(node);
|
||||
}
|
||||
|
||||
MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node,
|
||||
|
|
@ -146,8 +146,8 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
|
|||
if (p_cast1_node != nullptr) {
|
||||
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
|
||||
// this is fused Cast node, so expect 2 output edges
|
||||
if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) ||
|
||||
cast1_node.GetOutputEdgesCount() != 2){
|
||||
if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) ||
|
||||
cast1_node.GetOutputEdgesCount() != 2) {
|
||||
return matchResult;
|
||||
}
|
||||
const Node* p_pow_node = graph_utils::FirstChildByType(cast1_node, "Pow");
|
||||
|
|
@ -242,7 +242,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// if this is second formula and if pow node has Cast parent, expect mul5_node has Cast parent as well
|
||||
NodeArg* cast_input_arg = nullptr;
|
||||
if (second_formula) {
|
||||
const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast");
|
||||
const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast");
|
||||
if (p_cast1_node != nullptr) {
|
||||
// we've done the node check in second formula for pow node
|
||||
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
|
||||
|
|
@ -254,11 +254,11 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& cast3_node = *graph.GetNode(p_cast3_node->Index());
|
||||
if (!CheckNode(graph, cast3_node, "Cast", {9, 13}, node.GetExecutionProviderType(), true)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// overwrite and continue as usual
|
||||
p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul");
|
||||
nodes_to_fuse.push_back(cast3_node);
|
||||
// keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion()
|
||||
p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul");
|
||||
nodes_to_fuse.push_back(cast3_node);
|
||||
// keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -275,8 +275,8 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
}
|
||||
|
||||
if (input_index == -1) continue;
|
||||
// check same parent for both mul6 and pow, with or without cast
|
||||
if (input_index == -1) continue;
|
||||
// check same parent for both mul6 and pow, with or without cast
|
||||
if (cast_input_arg != nullptr) {
|
||||
if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != cast_input_arg->Name())
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
|
|||
nodes_to_remove.push_back(gemm_node);
|
||||
|
||||
// check if output node is Transpose
|
||||
if (output_node_ptr != gemm_node.OutputNodesEnd() &&
|
||||
gemm_node.InputDefs().size() <= 2 && // C is missing
|
||||
if (output_node_ptr != gemm_node.OutputNodesEnd() &&
|
||||
gemm_node.InputDefs().size() <= 2 && // C is missing
|
||||
output_node_ptr->OpType() == "Transpose") {
|
||||
Node& output_node = *graph.GetNode(output_node_ptr->Index());
|
||||
// (AB)' = B'A' : reverse the inputs
|
||||
|
|
@ -83,7 +83,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
|
|||
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
|
||||
node_it->GetOutputEdgesCount() == 1 &&
|
||||
graph.GetNodeOutputsInGraphOutputs(*node_it).empty() &&
|
||||
!graph.GetNodeProvidesGraphOutput(*node_it) &&
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
|
||||
return true;
|
||||
|
|
@ -94,7 +94,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
|
|||
// by the rule (AB)' = B'A' provided that C is missing
|
||||
// Supported for Opset >=11 as earlier opsets have C as a required input
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gemm", {11, 13}) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(node).empty() ||
|
||||
graph.GetNodeProvidesGraphOutput(node) ||
|
||||
// verify that C is missing
|
||||
node.InputDefs().size() > 2) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ namespace onnxruntime {
|
|||
X (def0/arg0) ---> Identity ---> Y
|
||||
*/
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
if (graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (!graph.GetNodeProvidesGraphOutput(node)) {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
|
@ -65,7 +65,7 @@ bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, c
|
|||
return true;
|
||||
}
|
||||
|
||||
bool node_output_is_graph_output = !graph.GetNodeOutputsInGraphOutputs(node).empty();
|
||||
bool node_output_is_graph_output = graph.GetNodeProvidesGraphOutput(node);
|
||||
|
||||
// relax the condition if Identity is connecting to graph output
|
||||
if (node.GetOutputEdgesCount() != 0 ||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
// and is assigned to the CPU EP (we have fp32 implementations of all kernels so forcing to fp32 is safe)
|
||||
if (node.GetInputEdgesCount() > 0 &&
|
||||
!node.ContainsSubgraph() &&
|
||||
graph.GetNodeOutputsInGraphOutputs(node).empty() &&
|
||||
!graph.GetNodeProvidesGraphOutput(node) &&
|
||||
node.GetExecutionProviderType() == kCpuExecutionProvider) {
|
||||
do {
|
||||
// find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10}) ||
|
||||
isinf_node.GetOutputEdgesCount() != 1 ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(isinf_node).empty()) {
|
||||
graph.GetNodeProvidesGraphOutput(isinf_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
Node& cast2_node = *graph.GetNode(cast2_node_itr->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13}) ||
|
||||
cast2_node.GetOutputEdgesCount() != 1 ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(cast2_node).empty()) {
|
||||
graph.GetNodeProvidesGraphOutput(cast2_node)) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(cast2_node);
|
||||
|
|
@ -80,7 +80,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
Node& reduce_sum_node = *graph.GetNode(reduce_sum_node_itr->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_sum_node, "ReduceSum", {1, 11, 13}) ||
|
||||
reduce_sum_node.GetOutputEdgesCount() != 1 ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(reduce_sum_node).empty()) {
|
||||
graph.GetNodeProvidesGraphOutput(reduce_sum_node)) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(reduce_sum_node);
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
|
||||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(reduce_mean_node).empty() ||
|
||||
graph.GetNodeProvidesGraphOutput(reduce_mean_node) ||
|
||||
!IsSupportedDataType(reduce_mean_node)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -377,7 +377,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, pow_node, 1) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(pow_node).empty() ||
|
||||
graph.GetNodeProvidesGraphOutput(pow_node) ||
|
||||
!IsSupportedDataType(pow_node)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
|
||||
// GEMM only supports unidirectional broadcast on the bias input C
|
||||
if (!gemm_input_defs.back()->Shape()) {
|
||||
continue;
|
||||
continue;
|
||||
}
|
||||
const auto& bias_shape = *gemm_input_defs.back()->Shape();
|
||||
const auto& M = matmul_output.Shape()->dim()[0];
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
|
|||
}
|
||||
|
||||
// if the node has Graph output, skip it too
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(*trans_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -117,9 +117,8 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
|
|||
* V
|
||||
*/
|
||||
static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
|
||||
std::unordered_map<NodeArg*, size_t>& consumer_count,
|
||||
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
|
||||
|
||||
std::unordered_map<NodeArg*, size_t>& consumer_count,
|
||||
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
|
||||
ORT_ENFORCE(cast != nullptr);
|
||||
auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
|
||||
if (transpose == nullptr) {
|
||||
|
|
@ -138,18 +137,18 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
|
|||
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
|
||||
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
|
||||
|
||||
const std::vector<NodeArg*> new_cast_input_defs {transpose_input};
|
||||
const std::vector<NodeArg*> new_cast_output_defs {&new_cast_output};
|
||||
const std::vector<NodeArg*> new_cast_input_defs{transpose_input};
|
||||
const std::vector<NodeArg*> new_cast_output_defs{&new_cast_output};
|
||||
const std::vector<NodeArg*> new_transpose_input_defs = {&new_cast_output};
|
||||
const std::vector<NodeArg*> new_transpose_output_defs = {cast_output};
|
||||
|
||||
(void) graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
|
||||
cast->OpType(),
|
||||
"Created a new Cast node to interchange Cast and Transpose nodes",
|
||||
new_cast_input_defs,
|
||||
new_cast_output_defs,
|
||||
&cast->GetAttributes(),
|
||||
cast->Domain());
|
||||
(void)graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
|
||||
cast->OpType(),
|
||||
"Created a new Cast node to interchange Cast and Transpose nodes",
|
||||
new_cast_input_defs,
|
||||
new_cast_output_defs,
|
||||
&cast->GetAttributes(),
|
||||
cast->Domain());
|
||||
|
||||
Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"),
|
||||
transpose->OpType(),
|
||||
|
|
@ -169,8 +168,7 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
|
|||
}
|
||||
|
||||
// Check whether the element_type is an allowed FusedMatMul data type or not.
|
||||
static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type)
|
||||
{
|
||||
static bool IsAllowedFusedMatMulDataType(ONNX_NAMESPACE::TensorProto_DataType element_type) {
|
||||
return element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
|
||||
element_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
|
||||
element_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE ||
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
|
|||
}
|
||||
// Bias the edge count to handle the case of a node that produces a graph
|
||||
// output.
|
||||
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph_.GetNodeProvidesGraphOutput(node)) {
|
||||
output_edges_count++;
|
||||
}
|
||||
return output_edges_count;
|
||||
|
|
@ -1145,7 +1145,7 @@ void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) {
|
|||
|
||||
// Verify that the node does not produce a graph output and produces output
|
||||
// for a single node.
|
||||
if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) {
|
||||
if (graph_.GetNodeProvidesGraphOutput(node) || node.GetOutputEdgesCount() != 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
|
|||
// check if dq_node has only one output edge and,
|
||||
// dq_node and q_node output are not graph outputs
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, dq_node, 1) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(q_node).empty()) {
|
||||
graph.GetNodeProvidesGraphOutput(q_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -173,8 +173,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
|
||||
if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) &&
|
||||
CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) &&
|
||||
graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() &&
|
||||
graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) {
|
||||
!graph.GetNodeProvidesGraphOutput(*p_add1) &&
|
||||
!graph.GetNodeProvidesGraphOutput(*p_add2)) {
|
||||
matched_format = Format::Format1;
|
||||
}
|
||||
}
|
||||
|
|
@ -191,8 +191,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
|
||||
if (CheckFirstAdd(*p_add1, ln_node.GetExecutionProviderType()) &&
|
||||
CheckSecondAdd(graph, *p_add2, ln_node.GetExecutionProviderType()) &&
|
||||
graph.GetNodeOutputsInGraphOutputs(*p_add1).empty() &&
|
||||
graph.GetNodeOutputsInGraphOutputs(*p_add2).empty()) {
|
||||
!graph.GetNodeProvidesGraphOutput(*p_add1) &&
|
||||
!graph.GetNodeProvidesGraphOutput(*p_add2)) {
|
||||
matched_format = Format::Format2;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
|
|||
}
|
||||
|
||||
bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_output_edges) {
|
||||
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
|
||||
if (graph.GetNodeProvidesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue