diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 898df7a11b..6b9230657b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -95,6 +95,13 @@ namespace Dml::GraphDescBuilder } } + // Delete the edges that reference nodes that are not reachable before removing the nodes themselves + graphIntermediateEdges.erase( + std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){ + return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited; + }), + graphIntermediateEdges.end()); + // Mapping from the old indices to the new indices that have been shifted after removing earlier nodes std::vector shiftedIndicesMapping(graphNodes.size()); @@ -134,12 +141,6 @@ namespace Dml::GraphDescBuilder intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex]; intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex]; } - - graphIntermediateEdges.erase( - std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){ - return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited; - }), - graphIntermediateEdges.end()); } GraphDesc BuildGraphDesc(