From c151afec7131e76d31b675817a4b3ff4ee80cfca Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 9 Jan 2023 15:40:09 -0800 Subject: [PATCH] [DML EP] Fix unconnected node removal logic (#14193) ### Description Fix unconnected node removal logic ### Motivation and Context The edges need to be removed before the nodes themselves, otherwise the indices will reference the wrong nodes. --- .../DmlExecutionProvider/src/GraphDescBuilder.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 898df7a11b..6b9230657b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -95,6 +95,13 @@ namespace Dml::GraphDescBuilder } } + // Delete the edges that reference nodes that are not reachable before removing the nodes themselves + graphIntermediateEdges.erase( + std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){ + return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited; + }), + graphIntermediateEdges.end()); + // Mapping from the old indices to the new indices that have been shifted after removing earlier nodes std::vector shiftedIndicesMapping(graphNodes.size()); @@ -134,12 +141,6 @@ namespace Dml::GraphDescBuilder intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex]; intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex]; } - - graphIntermediateEdges.erase( - std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){ - return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited; - }), - graphIntermediateEdges.end()); } GraphDesc BuildGraphDesc(