Fix minor bugs in RemoveDuplicateGraphTransformer (#883)

This commit is contained in:
Hariharan Seshadri 2019-04-23 11:30:20 -07:00 committed by GitHub
parent fb3b63438d
commit 4b4b585f58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -118,9 +118,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
const Node& output_node{*it};
if (output_node.OpType() == "Cast") {
// Skip if the node's output is also the output of the graph
// Skip this child node if this child node's output is also an output of the graph
if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) {
break;
continue;
}
auto src_type1 = output_node.InputDefs()[0]->Type();
auto dst_type1 = output_node.OutputDefs()[0]->Type();
@ -138,7 +138,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
num_child++;
}
if (child_removed == num_child && child_removed > 0) {
if (child_removed == num_child &&
child_removed > 0 &&
graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) {
removed_nodes.push_back(node.Index());
}
}