diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 48fbb6c6fe..55f9532076 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2554,8 +2554,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
   }

   bool early_termination = false;
-  // supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
-  supported_nodes_vector = parser_nodes_vector;
+  supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
   if (early_termination) {
     supported_nodes_vector.clear();
   }
@@ -2660,13 +2659,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
    * Enable EP related L2+ graph optimizations with steps:
    *
    * 1. Call provider bridge API to lookup pre-defined optimizer by name and get selection function
-   *    - Run selection function to get selection ComputeCapability
-   *    - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
+   * 2. Run selection function to get selection ComputeCapability
+   *    - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
    *
    *
    *
    * Current available optimizations:
-   *  - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32.
+   *  - (ConstantFoldingDQ) constant folding on DQ nodes, i.e., dequantize INT32, UINT16, INT16 constants to FP32.
    */

   std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)> selection_func;
@@ -2687,52 +2686,16 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
   SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq);

-  // Include nodes that are filtered out by TRT parser.
-  auto update_supported_node_vector = [&](SubGraph_t& supported_node_vector, SubGraphCollection_t& supported_nodes_vector) -> void {
-    if (!consumer_to_dq.empty()) {
-      const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
-      for (auto index : supported_node_vector.first) {
-        if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
-          continue;
-        }
-
-        auto dq_node_index = consumer_to_dq[node_index[index]];
-
-        // Check if DQ node is included in one of the subgraphs
-        auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
-          for (auto& node_vector : supported_nodes_vector) {
-            if (!node_vector.second) {
-              continue;
-            }
-            for (auto index : node_vector.first) {
-              if (node_index[index] == node_idx) {
-                return true;
-              }
-            }
-          }
-          return false;
-        };
-        if (in_the_subgraph_collection(dq_node_index)) {
-          continue;
-        }
-        // Find the iterator pointing to the target element
-        auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
-        if (it != node_index.end()) {
-          // Calculate the index
-          int idx = std::distance(node_index.begin(), it);
-          supported_node_vector.first.push_back(static_cast<size_t>(idx));
-          auto node = graph.GetNode(dq_node_index);
-          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser.";
-        }
-      }
-    }
-  };
-
   // Create ComputeCapability
   int number_of_trt_nodes = 0, subgraph_index = 0;
-  for (const auto& group : supported_nodes_vector) {
+  for (auto& group : supported_nodes_vector) {
     if (!group.first.empty()) {
-      // TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser.
+
+      if (!selection_cc.empty()) {
+        // Include DQ nodes that are filtered out by TRT parser
+        UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq);
+      }
+
       std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index);
       auto compute_capability = ComputeCapability::Create(std::move(sub_graph));
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 5889ff9960..45b15368ac 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -612,5 +612,13 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::unique_ptr<ComputeCapability> CreateOptimizationComputeCapability(ComputeCapability* selection_cc,
                                                                          std::unordered_set<NodeIndex>& trt_selection_node_set,
                                                                          ComputeCapability* trt_cc) const;
+  /**
+   * This function helps add back the DQ nodes that are filtered out by the TRT parser.
+   * The reason is that these DQ nodes can be constant-folded and dequantized by the ConstantFoldingDQ optimizer during ORT L2+ optimization.
+   */
+  void UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
+                                      SubGraph_t& supported_node_vector,
+                                      SubGraphCollection_t& supported_nodes_vector,
+                                      std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const;
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc
index aeba2854b9..702bc6108b 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc
@@ -293,7 +293,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph,
       const Node& consumer_node = *node->OutputNodesBegin();
       selection_node_set.insert(index);
       consumer_to_dq[consumer_node.Index()] = index;
-      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " < -" << node->Name();
+      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name();
     }
   }
   LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected.";
@@ -330,4 +330,61 @@ std::unique_ptr<ComputeCapability> TensorrtExecutionProvider::CreateOptimization
   compute_capability->copy_optimization_func(selection_cc);
   return compute_capability;
 }
+
+/**
+ * This function helps add back the DQ nodes that are filtered out by the TRT parser.
+ * The reason is that these DQ nodes can be constant-folded and dequantized by the ConstantFoldingDQ optimizer during ORT L2+ optimization.
+ */
+void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
+                                                               SubGraph_t& supported_node_vector,
+                                                               SubGraphCollection_t& supported_nodes_vector,
+                                                               std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const {
+  if (consumer_to_dq.empty()) {
+    return;
+  }
+
+  if (!supported_node_vector.second) {
+    return;
+  }
+
+  const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
+  auto supported_nodes = supported_node_vector.first;
+  for (auto index : supported_nodes) {
+    if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
+      continue;
+    }
+
+    auto dq_node_index = consumer_to_dq[node_index[index]];
+
+    // Check if DQ node is included in one of the subgraphs
+    auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
+      for (auto& node_vector : supported_nodes_vector) {
+        if (!node_vector.second) {
+          continue;
+        }
+        for (auto i : node_vector.first) {
+          if (node_index[i] == node_idx) {
+            return true;
+          }
+        }
+      }
+      return false;
+    };
+
+    // If the DQ node is already in the subgraph, do nothing.
+    if (in_the_subgraph_collection(dq_node_index)) {
+      continue;
+    }
+
+    // Find the iterator pointing to the target element
+    auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
+    if (it != node_index.end()) {
+      // Calculate the index
+      int idx = std::distance(node_index.begin(), it);
+      supported_node_vector.first.push_back(static_cast<size_t>(idx));
+      auto node = graph.GetNode(dq_node_index);
+      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser.";
+    }
+  }
+}
 }  // namespace onnxruntime
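
Note (not part of the diff): the TRT EP records each candidate subgraph as positions into graph.GetNodesInTopologicalOrder(1), so re-including a filtered-out DQ node means translating its NodeIndex back into a position in that order, and only when the DQ node is not already covered by another supported subgraph. Below is a minimal, self-contained C++ sketch of that re-inclusion logic using plain STL stand-ins for the ONNX Runtime types (SubGraph_t as std::pair<std::vector<size_t>, bool>; consumer_to_dq mapping a consumer's NodeIndex to the NodeIndex of the DQ node feeding it). The helper name AddBackFilteredDQNodes and the sample indices are illustrative assumptions, not code from this PR.

// Standalone sketch (not ONNX Runtime code): re-include a DQ node whose consumer
// sits in a supported subgraph, mirroring the shape of UpdateSupportedNodeVectorForDQ.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

using NodeIndex = std::size_t;
// Positions into the topologically-ordered node list, plus a "supported" flag.
using SubGraph_t = std::pair<std::vector<std::size_t>, bool>;
using SubGraphCollection_t = std::vector<SubGraph_t>;

void AddBackFilteredDQNodes(const std::vector<NodeIndex>& topo_order,
                            SubGraph_t& group,
                            const SubGraphCollection_t& all_groups,
                            const std::unordered_map<NodeIndex, NodeIndex>& consumer_to_dq) {
  if (!group.second || consumer_to_dq.empty()) return;

  // True if node_idx already belongs to any supported subgraph.
  auto covered = [&](NodeIndex node_idx) {
    for (const auto& g : all_groups) {
      if (!g.second) continue;
      for (auto i : g.first) {
        if (topo_order[i] == node_idx) return true;
      }
    }
    return false;
  };

  // Iterate over a snapshot while appending to group.first (as the helper does via a copy).
  const std::vector<std::size_t> snapshot = group.first;
  for (auto i : snapshot) {
    auto it = consumer_to_dq.find(topo_order[i]);
    if (it == consumer_to_dq.end() || covered(it->second)) continue;
    auto pos = std::find(topo_order.begin(), topo_order.end(), it->second);
    if (pos != topo_order.end()) {
      group.first.push_back(static_cast<std::size_t>(std::distance(topo_order.begin(), pos)));
    }
  }
}

int main() {
  // Nodes 0..3 in topological order; node 1 is a DQ feeding node 2 (its consumer).
  std::vector<NodeIndex> topo_order = {0, 1, 2, 3};
  SubGraphCollection_t groups = {{{0, 2, 3}, true}};  // parser dropped the DQ node (index 1)
  std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq = {{2, 1}};

  AddBackFilteredDQNodes(topo_order, groups[0], groups, consumer_to_dq);
  for (auto i : groups[0].first) std::cout << i << ' ';  // prints: 0 2 3 1
  std::cout << '\n';
}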