mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
[TensorRT EP] Only instantiate TRT builder once (#18100)
The TRT builder instantiation is slow (see [here](https://github.com/microsoft/onnxruntime/issues/18071)). In the current TRT EP, we instantiate a builder object every time we need one. Multiple places need the TRT builder, so this causes a huge performance overhead.
This commit is contained in:
parent
6f9f653ada
commit
18a3675bf7
2 changed files with 28 additions and 7 deletions
|
|
@ -1272,6 +1272,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the pointer to the IBuilder instance.
|
||||
// Note: This function is not thread safe. Calls to this function from different threads must be serialized
|
||||
// even though it doesn't make sense to have multiple threads initializing the same inference session.
|
||||
nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
|
||||
if (!builder_) {
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
{
|
||||
auto lock = GetApiLock();
|
||||
builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
}
|
||||
}
|
||||
return builder_.get();
|
||||
}
|
||||
|
||||
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
|
||||
if (info_.custom_op_domain_list.empty()) {
|
||||
common::Status status = CreateTensorRTCustomOpDomainList(info_);
|
||||
|
|
@ -1633,7 +1647,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
// Get supported node list recursively
|
||||
SubGraphCollection_t parser_nodes_list;
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
auto trt_builder = GetBuilder();
|
||||
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
|
||||
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
|
||||
|
||||
|
|
@ -1985,7 +1999,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
}
|
||||
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
auto trt_builder = GetBuilder();
|
||||
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
|
||||
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
|
||||
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
|
||||
|
|
@ -2438,7 +2452,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
parsers_.emplace(fused_node.Name(), std::move(trt_parser));
|
||||
engines_.emplace(fused_node.Name(), std::move(trt_engine));
|
||||
contexts_.emplace(fused_node.Name(), std::move(trt_context));
|
||||
builders_.emplace(fused_node.Name(), std::move(trt_builder));
|
||||
networks_.emplace(fused_node.Name(), std::move(trt_network));
|
||||
input_info_[fused_node.Name()].push_back(input_indexes);
|
||||
output_info_[fused_node.Name()].push_back(output_indexes);
|
||||
|
|
@ -2456,8 +2469,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
if (!tactic_sources_.empty()) {
|
||||
tactics = GetTacticSourceFromString(tactic_sources_);
|
||||
}
|
||||
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
|
||||
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
|
||||
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
|
||||
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
|
||||
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
|
||||
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
|
||||
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
|
||||
|
|
@ -2490,7 +2503,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
|
||||
auto fused_node_name = trt_state->fused_node_name;
|
||||
auto& shape_ranges = trt_state->input_shape_ranges;
|
||||
auto trt_builder = trt_state->builder->get();
|
||||
auto trt_builder = trt_state->builder;
|
||||
auto trt_engine = trt_state->engine->get();
|
||||
auto trt_context = trt_state->context->get();
|
||||
auto trt_profiles = trt_state->profiles;
|
||||
|
|
|
|||
|
|
@ -105,10 +105,10 @@ struct TensorrtFuncState {
|
|||
DestroyFunc test_release_func = nullptr;
|
||||
AllocatorHandle allocator = nullptr;
|
||||
std::string fused_node_name;
|
||||
nvinfer1::IBuilder* builder;
|
||||
tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
|
||||
std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
|
||||
std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
|
||||
std::unique_ptr<nvinfer1::IBuilder>* builder = nullptr;
|
||||
std::unique_ptr<nvinfer1::INetworkDefinition>* network = nullptr;
|
||||
std::vector<std::unordered_map<std::string, size_t>> input_info;
|
||||
std::vector<std::unordered_map<std::string, size_t>> output_info;
|
||||
|
|
@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
|
||||
mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
|
||||
|
||||
mutable std::unique_ptr<nvinfer1::IBuilder> builder_;
|
||||
|
||||
// Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading.
|
||||
// In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client.
|
||||
// But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
|
||||
|
|
@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
void CaptureBegin();
|
||||
void CaptureEnd();
|
||||
void IncrementRegularRunCountBeforeGraphCapture();
|
||||
|
||||
/**
|
||||
* Get the pointer to the IBuilder instance.
|
||||
* This function only creates the instance the first time it is called.
|
||||
*/
|
||||
nvinfer1::IBuilder* GetBuilder() const;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue