From b829d5532047f2424930fc2337cd2d954f9f987d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sun, 13 Oct 2019 11:37:57 +1000 Subject: [PATCH] Fix invalid logic that ran past end of nodes and double increment. (#2117) --- .../mkldnn/mkldnn_execution_provider.cc | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc index 17093c15e3..84c20531d6 100644 --- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc +++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc @@ -93,24 +93,25 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_ bool FP16_graph = false; bool mkldnn_nodes_in_the_graph = false; + int max_node_index = graph_viewer.MaxNodeIndex(); - if (graph_viewer.MaxNodeIndex() > 0) { - int index = 0; - auto node = graph_viewer.GetNode(index); - while (node == NULL) { - index++; - node = graph_viewer.GetNode(index); - } - if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) + for (auto node_index = 0; node_index < max_node_index; node_index++) { + auto node = graph_viewer.GetNode(node_index); + if (node == NULL) + continue; + + if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) { FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos; + break; + } } - for (auto node_index = 0; node_index < graph_viewer.MaxNodeIndex(); node_index++) { + for (auto node_index = 0; node_index < max_node_index; node_index++) { auto node = graph_viewer.GetNode(node_index); if (node == nullptr) { - node_index++; continue; } + auto op_it = mkldnn_ops_.find(node->OpType()); if (op_it != mkldnn_ops_.end()) { mkldnn_nodes_in_the_graph = true;