[DML EP] Fix unconnected node removal logic (#14193)

### Description
Fix unconnected node removal logic



### Motivation and Context
The edges need to be removed before the nodes themselves, otherwise the
indices will reference the wrong nodes.
This commit is contained in:
Patrice Vignola 2023-01-09 15:40:09 -08:00 committed by GitHub
parent 906f578be8
commit c151afec71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -95,6 +95,13 @@ namespace Dml::GraphDescBuilder
}
}
// Delete the edges that reference nodes that are not reachable before removing the nodes themselves
graphIntermediateEdges.erase(
std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){
return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited;
}),
graphIntermediateEdges.end());
// Mapping from the old indices to the new indices that have been shifted after removing earlier nodes
std::vector<uint32_t> shiftedIndicesMapping(graphNodes.size());
@ -134,12 +141,6 @@ namespace Dml::GraphDescBuilder
intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex];
intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex];
}
graphIntermediateEdges.erase(
std::remove_if(graphIntermediateEdges.begin(), graphIntermediateEdges.end(), [&nodesData](const auto& intermediateEdge){
return nodesData[intermediateEdge.FromNodeIndex].state == NodeState::NotVisited || nodesData[intermediateEdge.ToNodeIndex].state == NodeState::NotVisited;
}),
graphIntermediateEdges.end());
}
GraphDesc BuildGraphDesc(