mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
update TRT EP
This commit is contained in:
parent
df5aca92d6
commit
60d9599383
3 changed files with 77 additions and 49 deletions
|
|
@ -2554,8 +2554,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
}
|
||||
|
||||
bool early_termination = false;
|
||||
// supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
|
||||
supported_nodes_vector = parser_nodes_vector;
|
||||
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
|
||||
if (early_termination) {
|
||||
supported_nodes_vector.clear();
|
||||
}
|
||||
|
|
@ -2660,13 +2659,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
* Enable EP related L2+ graph optimizations with steps:
|
||||
*
|
||||
* 1. Call provider bridge API to lookup pre-defined optimizer by name and get selection function
|
||||
* - Run selection function to get selection ComputeCapability
|
||||
* - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
|
||||
* 2. Run selection function to get selection ComputeCapability
|
||||
- ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
|
||||
*
|
||||
*
|
||||
*
|
||||
* Current available optimizations:
|
||||
* - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32.
|
||||
* - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constant to FP32.
|
||||
*/
|
||||
|
||||
std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)> selection_func;
|
||||
|
|
@ -2687,52 +2686,16 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
|
||||
SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq);
|
||||
|
||||
// Include nodes that are filtered out by TRT parser.
|
||||
auto update_supported_node_vector = [&](SubGraph_t& supported_node_vector, SubGraphCollection_t& supported_nodes_vector) -> void {
|
||||
if (!consumer_to_dq.empty()) {
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
|
||||
for (auto index : supported_node_vector.first) {
|
||||
if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto dq_node_index = consumer_to_dq[node_index[index]];
|
||||
|
||||
// Check if DQ node is included in one of the subgraphs
|
||||
auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
|
||||
for (auto& node_vector : supported_nodes_vector) {
|
||||
if (!node_vector.second) {
|
||||
continue;
|
||||
}
|
||||
for (auto index : node_vector.first) {
|
||||
if (node_index[index] == node_idx) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
if (in_the_subgraph_collection(dq_node_index)) {
|
||||
continue;
|
||||
}
|
||||
// Find the iterator pointing to the target element
|
||||
auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
|
||||
if (it != node_index.end()) {
|
||||
// Calculate the index
|
||||
int idx = std::distance(node_index.begin(), it);
|
||||
supported_node_vector.first.push_back(static_cast<NodeIndex>(idx));
|
||||
auto node = graph.GetNode(dq_node_index);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser.";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create ComputeCapability
|
||||
int number_of_trt_nodes = 0, subgraph_index = 0;
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
for (auto& group : supported_nodes_vector) {
|
||||
if (!group.first.empty()) {
|
||||
// TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser.
|
||||
|
||||
if (!selection_cc.empty()) {
|
||||
// Include DQ nodes that are filtered out by TRT parser
|
||||
UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq);
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index);
|
||||
auto compute_capability = ComputeCapability::Create(std::move(sub_graph));
|
||||
|
||||
|
|
|
|||
|
|
@ -612,5 +612,13 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
std::unique_ptr<ComputeCapability> CreateOptimizationComputeCapability(ComputeCapability* selection_cc,
|
||||
std::unordered_set<NodeIndex>& trt_selection_node_set,
|
||||
ComputeCapability* trt_cc) const;
|
||||
/**
|
||||
* This function helps add back the DQ nodes that are filtered out by TRT parser.
|
||||
* The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization.
|
||||
*/
|
||||
void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
|
||||
SubGraph_t& supported_node_vector,
|
||||
SubGraphCollection_t& supported_nodes_vector,
|
||||
std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -293,7 +293,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph,
|
|||
const Node& consumer_node = *node->OutputNodesBegin();
|
||||
selection_node_set.insert(index);
|
||||
consumer_to_dq[consumer_node.Index()] = index;
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " < -" << node->Name();
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name();
|
||||
}
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected.";
|
||||
|
|
@ -330,4 +330,61 @@ std::unique_ptr<ComputeCapability> TensorrtExecutionProvider::CreateOptimization
|
|||
compute_capability->copy_optimization_func(selection_cc);
|
||||
return compute_capability;
|
||||
}
|
||||
|
||||
/**
|
||||
* This function helps add back the DQ nodes that are filtered out by TRT parser.
|
||||
* The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization.
|
||||
*/
|
||||
void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
|
||||
SubGraph_t& supported_node_vector,
|
||||
SubGraphCollection_t& supported_nodes_vector,
|
||||
std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const {
|
||||
if (consumer_to_dq.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!supported_node_vector.second) {
|
||||
return;
|
||||
}
|
||||
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
|
||||
auto supported_nodes = supported_node_vector.first;
|
||||
for (auto index : supported_nodes) {
|
||||
if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto dq_node_index = consumer_to_dq[node_index[index]];
|
||||
|
||||
// Check if DQ node is included in one of the subgraphs
|
||||
auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
|
||||
for (auto& node_vector : supported_nodes_vector) {
|
||||
if (!node_vector.second) {
|
||||
continue;
|
||||
}
|
||||
for (auto i : node_vector.first) {
|
||||
if (node_index[i] == node_idx) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// If the DQ node is already in the subgraph, do nothing.
|
||||
if (in_the_subgraph_collection(dq_node_index)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find the iterator pointing to the target element
|
||||
auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
|
||||
if (it != node_index.end()) {
|
||||
// Calculate the index
|
||||
int idx = std::distance(node_index.begin(), it);
|
||||
supported_node_vector.first.push_back(static_cast<NodeIndex>(idx));
|
||||
auto node = graph.GetNode(dq_node_index);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser.";
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue