Handle optional inputs and remove more empty shape nodes in TensorRT EP (#3455)

* check optional inputs and remove more empty shape affected nodes

* fix some minor issues

* update code according to feedback
This commit is contained in:
stevenlix 2020-04-10 11:13:38 -07:00 committed by GitHub
parent d09d4a6b0d
commit 56e85484ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -183,11 +183,10 @@ bool FindCycleHelper(int i, const std::list<int>* adjacency_map,
// Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape
SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) {
// Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models.
// Here only NonZero, NonMaxSuppression and TopK related empty shape nodes are removed, particularly for RCNN models.
// TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
const std::string exclude_dim_name1 = "NonZero";
const std::string exclude_dim_name2 = "NonMaxSuppression";
const std::vector<std::string> exclude_dim_names{"NonZero", "NonMaxSuppression", "TopK"};
SubGraphCollection_t parser_nodes_vector = {{{}, false}};
std::vector<size_t> nodes_vector(node_index.size());
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
@ -201,8 +200,13 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph
for (const auto& dim : input_shape->dim()) {
std::string dim_name = dim.dim_param();
if (!dim_name.empty()) {
if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) {
exclude_node = true;
for (const auto& exclude : exclude_dim_names) {
if (dim_name.find(exclude) != std::string::npos) {
exclude_node = true;
break;
}
}
if (exclude_node) {
break;
}
}
@ -260,7 +264,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
}
}
// For output searching, there is two special cases,
// For output searching, there are two special cases,
// One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once,
// if the output is connected to nodes that don't belong to the subgraph, the output need to be added
// to the output list
@ -322,11 +326,15 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
meta_def->domain = kMSDomain;
for (const auto& input : inputs) {
meta_def->inputs.push_back(input.second->Name());
if (input.second->Exists()) {
meta_def->inputs.push_back(input.second->Name());
}
}
for (const auto& output : outputs) {
meta_def->outputs.push_back(output.second->Name());
if (output.second->Exists()) {
meta_def->outputs.push_back(output.second->Name());
}
}
meta_def->since_version = 1;
@ -385,6 +393,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
}
// Add initializers to the subgraph
const auto& init_tensors = graph.GetAllInitializedTensors();
for (const auto& tensor : init_tensors) {
graph_build.AddInitializedTensor(*(tensor.second));
}
ORT_ENFORCE(graph_build.Resolve().IsOK());
// Add parent graph output to the subgraph
@ -400,13 +414,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
auto& graph_build_outputs = graph_build.GetOutputs();
subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
graph_build.SetOutputs(graph_build_outputs);
// Add initializers to the subgraph
const auto& init_tensors = graph.GetAllInitializedTensors();
for (const auto& tensor : init_tensors) {
graph_build.AddInitializedTensor(*(tensor.second));
}
ORT_ENFORCE(graph_build.Resolve().IsOK());
// Check if input tensors have shapes