diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc index 17093c15e3..84c20531d6 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc @@ -93,24 +93,25 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_ bool FP16_graph = false; bool mkldnn_nodes_in_the_graph = false; + int max_node_index = graph_viewer.MaxNodeIndex(); - if (graph_viewer.MaxNodeIndex() > 0) { - int index = 0; - auto node = graph_viewer.GetNode(index); - while (node == NULL) { - index++; - node = graph_viewer.GetNode(index); - } - if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) + for (auto node_index = 0; node_index < max_node_index; node_index++) { + auto node = graph_viewer.GetNode(node_index); + if (node == NULL) + continue; + + if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) { FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos; + break; + } } - for (auto node_index = 0; node_index < graph_viewer.MaxNodeIndex(); node_index++) { + for (auto node_index = 0; node_index < max_node_index; node_index++) { auto node = graph_viewer.GetNode(node_index); if (node == nullptr) { - node_index++; continue; } + auto op_it = mkldnn_ops_.find(node->OpType()); if (op_it != mkldnn_ops_.end()) { mkldnn_nodes_in_the_graph = true;