mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Fix cast optimizer (#524)
This commit is contained in:
parent
2e6ec07d9a
commit
9fb80ea927
1 changed files with 8 additions and 0 deletions
|
|
@ -105,14 +105,22 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
if (node.OpType() == "Cast") {
|
||||
// if cast's next node is also cast and next cast's output type equal to cast's input type
|
||||
// remove those two cast.
|
||||
// boolean is an exception case for this optimization
|
||||
auto src_type = node.InputDefs()[0]->Type();
|
||||
auto dst_type = node.OutputDefs()[0]->Type();
|
||||
if (*src_type == "tensor(bool)" || *dst_type == "tensor(bool)") return Status::OK();
|
||||
auto input = node.MutableInputDefs()[0];
|
||||
int child_removed = 0;
|
||||
int num_child = 0;
|
||||
auto output_args = graph.GetOutputs();
|
||||
std::unordered_set<const onnxruntime::NodeArg*> graph_outputs(output_args.begin(), output_args.end());
|
||||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
const Node& output_node{*it};
|
||||
if (output_node.OpType() == "Cast") {
|
||||
// Skip if the node's output is also the output of the graph
|
||||
if (graph_outputs.find(output_node.OutputDefs()[0]) != graph_outputs.end()) {
|
||||
break;
|
||||
}
|
||||
auto src_type1 = output_node.InputDefs()[0]->Type();
|
||||
auto dst_type1 = output_node.OutputDefs()[0]->Type();
|
||||
if (src_type == dst_type1 && src_type1 == dst_type) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue