mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
Handle optional inputs and remove more empty shape nodes in TensorRT EP (#3455)
* check optional inputs and remove more empty shape affected nodes * fix some minor issues * update code according to feedback
This commit is contained in:
parent
d09d4a6b0d
commit
56e85484ba
1 changed files with 22 additions and 15 deletions
|
|
@ -183,11 +183,10 @@ bool FindCycleHelper(int i, const std::list<int>* adjacency_map,
|
|||
|
||||
// Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape
|
||||
SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) {
|
||||
// Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models.
|
||||
// Here only NonZero, NonMaxSuppression and TopK related empty shape nodes are removed, particularly for RCNN models.
|
||||
// TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
const std::string exclude_dim_name1 = "NonZero";
|
||||
const std::string exclude_dim_name2 = "NonMaxSuppression";
|
||||
const std::vector<std::string> exclude_dim_names{"NonZero", "NonMaxSuppression", "TopK"};
|
||||
SubGraphCollection_t parser_nodes_vector = {{{}, false}};
|
||||
std::vector<size_t> nodes_vector(node_index.size());
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
|
|
@ -201,8 +200,13 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph
|
|||
for (const auto& dim : input_shape->dim()) {
|
||||
std::string dim_name = dim.dim_param();
|
||||
if (!dim_name.empty()) {
|
||||
if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) {
|
||||
exclude_node = true;
|
||||
for (const auto& exclude : exclude_dim_names) {
|
||||
if (dim_name.find(exclude) != std::string::npos) {
|
||||
exclude_node = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (exclude_node) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -260,7 +264,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
|
|||
}
|
||||
}
|
||||
|
||||
// For output searching, there is two special cases,
|
||||
// For output searching, there are two special cases,
|
||||
// One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once,
|
||||
// if the output is connected to nodes that don't belong to the subgraph, the output need to be added
|
||||
// to the output list
|
||||
|
|
@ -322,11 +326,15 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
|
|||
meta_def->domain = kMSDomain;
|
||||
|
||||
for (const auto& input : inputs) {
|
||||
meta_def->inputs.push_back(input.second->Name());
|
||||
if (input.second->Exists()) {
|
||||
meta_def->inputs.push_back(input.second->Name());
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& output : outputs) {
|
||||
meta_def->outputs.push_back(output.second->Name());
|
||||
if (output.second->Exists()) {
|
||||
meta_def->outputs.push_back(output.second->Name());
|
||||
}
|
||||
}
|
||||
|
||||
meta_def->since_version = 1;
|
||||
|
|
@ -385,6 +393,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
|
||||
}
|
||||
|
||||
// Add initializers to the subgraph
|
||||
const auto& init_tensors = graph.GetAllInitializedTensors();
|
||||
for (const auto& tensor : init_tensors) {
|
||||
graph_build.AddInitializedTensor(*(tensor.second));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(graph_build.Resolve().IsOK());
|
||||
|
||||
// Add parent graph output to the subgraph
|
||||
|
|
@ -400,13 +414,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
auto& graph_build_outputs = graph_build.GetOutputs();
|
||||
subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
|
||||
graph_build.SetOutputs(graph_build_outputs);
|
||||
|
||||
// Add initializers to the subgraph
|
||||
const auto& init_tensors = graph.GetAllInitializedTensors();
|
||||
for (const auto& tensor : init_tensors) {
|
||||
graph_build.AddInitializedTensor(*(tensor.second));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(graph_build.Resolve().IsOK());
|
||||
|
||||
// Check if input tensors have shapes
|
||||
|
|
|
|||
Loading…
Reference in a new issue