From 56e85484ba2f8079a070c4c6e1f50fe024b1f4e7 Mon Sep 17 00:00:00 2001 From: stevenlix <38092805+stevenlix@users.noreply.github.com> Date: Fri, 10 Apr 2020 11:13:38 -0700 Subject: [PATCH] Handle optional inputs and remove more empty shape nodes in TensorRT EP (#3455) * check optional inputs and remove more empty shape affected nodes * fix some minor issues * update code according to feedback --- .../tensorrt/tensorrt_execution_provider.cc | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a090837361..2d5db7de07 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -183,11 +183,10 @@ bool FindCycleHelper(int i, const std::list* adjacency_map, // Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) { - // Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models. + // Here only NonZero, NonMaxSuppression and TopK related empty shape nodes are removed, particularly for RCNN models. // TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around const std::vector& node_index = graph.GetNodesInTopologicalOrder(); - const std::string exclude_dim_name1 = "NonZero"; - const std::string exclude_dim_name2 = "NonMaxSuppression"; + const std::vector exclude_dim_names{"NonZero", "NonMaxSuppression", "TopK"}; SubGraphCollection_t parser_nodes_vector = {{{}, false}}; std::vector nodes_vector(node_index.size()); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); @@ -201,8 +200,13 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph for (const auto& dim : input_shape->dim()) { std::string dim_name = dim.dim_param(); if (!dim_name.empty()) { - if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) { - exclude_node = true; + for (const auto& exclude : exclude_dim_names) { + if (dim_name.find(exclude) != std::string::npos) { + exclude_node = true; + break; + } + } + if (exclude_node) { break; } } @@ -260,7 +264,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } } - // For output searching, there is two special cases, + // For output searching, there are two special cases, // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, // if the output is connected to nodes that don't belong to the subgraph, the output need to be added // to the output list @@ -322,11 +326,15 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph meta_def->domain = kMSDomain; for (const auto& input : inputs) { - meta_def->inputs.push_back(input.second->Name()); + if (input.second->Exists()) { + meta_def->inputs.push_back(input.second->Name()); + } } for (const auto& output : outputs) { - meta_def->outputs.push_back(output.second->Name()); + if (output.second->Exists()) { + meta_def->outputs.push_back(output.second->Name()); + } } meta_def->since_version = 1; @@ -385,6 +393,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); } + // Add initializers to the subgraph + const auto& init_tensors = graph.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + graph_build.AddInitializedTensor(*(tensor.second)); + } + ORT_ENFORCE(graph_build.Resolve().IsOK()); // Add parent graph output to the subgraph @@ -400,13 +414,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto& graph_build_outputs = graph_build.GetOutputs(); subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); graph_build.SetOutputs(graph_build_outputs); - - // Add initializers to the subgraph - const auto& init_tensors = graph.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - graph_build.AddInitializedTensor(*(tensor.second)); - } - ORT_ENFORCE(graph_build.Resolve().IsOK()); // Check if input tensors have shapes