From 4b4b585f58cd329a1c7737aa6bcfaf1a47cfd76a Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 23 Apr 2019 11:30:20 -0700 Subject: [PATCH] Fix minor bugs in RemoveDuplicateGraphTransformer (#883) --- onnxruntime/core/optimizer/insert_cast_transformer.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 873633555c..12147acaf4 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -118,9 +118,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { const Node& output_node{*it}; if (output_node.OpType() == "Cast") { - // Skip if the node's output is also the output of the graph + // Skip this child node if this child node's output is also an output of the graph if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) { - break; + continue; } auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); @@ -138,7 +138,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { num_child++; } - if (child_removed == num_child && child_removed > 0) { + if (child_removed == num_child && + child_removed > 0 && + graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) { removed_nodes.push_back(node.Index()); } }