Fix invalid logic that ran past end of nodes and double increment. (#2117)

This commit is contained in:
Scott McKay 2019-10-13 11:37:57 +10:00 committed by GitHub
parent eb24617d2e
commit b829d55320
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -93,24 +93,25 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_
bool FP16_graph = false;
bool mkldnn_nodes_in_the_graph = false;
int max_node_index = graph_viewer.MaxNodeIndex();
if (graph_viewer.MaxNodeIndex() > 0) {
int index = 0;
auto node = graph_viewer.GetNode(index);
while (node == NULL) {
index++;
node = graph_viewer.GetNode(index);
}
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr)
for (auto node_index = 0; node_index < max_node_index; node_index++) {
auto node = graph_viewer.GetNode(node_index);
if (node == NULL)
continue;
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) {
FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos;
break;
}
}
for (auto node_index = 0; node_index < graph_viewer.MaxNodeIndex(); node_index++) {
for (auto node_index = 0; node_index < max_node_index; node_index++) {
auto node = graph_viewer.GetNode(node_index);
if (node == nullptr) {
node_index++;
continue;
}
auto op_it = mkldnn_ops_.find(node->OpType());
if (op_it != mkldnn_ops_.end()) {
mkldnn_nodes_in_the_graph = true;