Fix cast optimizer (#524)

This commit is contained in:
Raymond Yang 2019-02-28 13:35:52 -08:00 committed by Changming Sun
parent 2e6ec07d9a
commit 9fb80ea927

View file

@ -105,14 +105,22 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
if (node.OpType() == "Cast") {
// if cast's next node is also cast and next cast's output type equal to cast's input type
// remove those two cast.
// boolean is an exception case for this optimization
auto src_type = node.InputDefs()[0]->Type();
auto dst_type = node.OutputDefs()[0]->Type();
if (*src_type == "tensor(bool)" || *dst_type == "tensor(bool)") return Status::OK();
auto input = node.MutableInputDefs()[0];
int child_removed = 0;
int num_child = 0;
auto output_args = graph.GetOutputs();
std::unordered_set<const onnxruntime::NodeArg*> graph_outputs(output_args.begin(), output_args.end());
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
const Node& output_node{*it};
if (output_node.OpType() == "Cast") {
// Skip if the node's output is also the output of the graph
if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) {
break;
}
auto src_type1 = output_node.InputDefs()[0]->Type();
auto dst_type1 = output_node.OutputDefs()[0]->Type();
if (src_type == dst_type1 && src_type1 == dst_type) {