mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Fix invalid logic that ran past end of nodes and double increment. (#2117)
This commit is contained in:
parent
eb24617d2e
commit
b829d55320
1 changed files with 11 additions and 10 deletions
|
|
@ -93,24 +93,25 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_
|
|||
|
||||
bool FP16_graph = false;
|
||||
bool mkldnn_nodes_in_the_graph = false;
|
||||
int max_node_index = graph_viewer.MaxNodeIndex();
|
||||
|
||||
if (graph_viewer.MaxNodeIndex() > 0) {
|
||||
int index = 0;
|
||||
auto node = graph_viewer.GetNode(index);
|
||||
while (node == NULL) {
|
||||
index++;
|
||||
node = graph_viewer.GetNode(index);
|
||||
}
|
||||
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr)
|
||||
for (auto node_index = 0; node_index < max_node_index; node_index++) {
|
||||
auto node = graph_viewer.GetNode(node_index);
|
||||
if (node == NULL)
|
||||
continue;
|
||||
|
||||
if (!node->InputDefs().empty() && node->InputDefs()[0]->Type() != nullptr) {
|
||||
FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto node_index = 0; node_index < graph_viewer.MaxNodeIndex(); node_index++) {
|
||||
for (auto node_index = 0; node_index < max_node_index; node_index++) {
|
||||
auto node = graph_viewer.GetNode(node_index);
|
||||
if (node == nullptr) {
|
||||
node_index++;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto op_it = mkldnn_ops_.find(node->OpType());
|
||||
if (op_it != mkldnn_ops_.end()) {
|
||||
mkldnn_nodes_in_the_graph = true;
|
||||
|
|
|
|||
Loading…
Reference in a new issue