diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 6c36f62015..41500af9c3 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -105,14 +105,22 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (node.OpType() == "Cast") { // if cast's next node is also cast and next cast's output type equal to cast's input type // remove those two cast. + // boolean is an exception case for this optimization auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); + if (*src_type == "tensor(bool)" || *dst_type == "tensor(bool)") return Status::OK(); auto input = node.MutableInputDefs()[0]; int child_removed = 0; int num_child = 0; + auto output_args = graph.GetOutputs(); + std::unordered_set graph_outputs(output_args.begin(), output_args.end()); for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { const Node& output_node{*it}; if (output_node.OpType() == "Cast") { + // Skip if the node's output is also the output of the graph + if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) { + break; + } auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); if (src_type == dst_type1 && src_type1 == dst_type) {