mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Fix minor bugs in RemoveDuplicateGraphTransformer (#883)
This commit is contained in:
parent
fb3b63438d
commit
4b4b585f58
1 changed files with 5 additions and 3 deletions
|
|
@ -118,9 +118,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
const Node& output_node{*it};
|
||||
if (output_node.OpType() == "Cast") {
|
||||
// Skip if the node's output is also the output of the graph
|
||||
// Skip this child node if this child node's output is also an output of the graph
|
||||
if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
auto src_type1 = output_node.InputDefs()[0]->Type();
|
||||
auto dst_type1 = output_node.OutputDefs()[0]->Type();
|
||||
|
|
@ -138,7 +138,9 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
num_child++;
|
||||
}
|
||||
|
||||
if (child_removed == num_child && child_removed > 0) {
|
||||
if (child_removed == num_child &&
|
||||
child_removed > 0 &&
|
||||
graph_outputs.find(node.OutputDefs()[0]) == graph_outputs.end()) {
|
||||
removed_nodes.push_back(node.Index());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue