diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 8d342eaea4..7ca71dd532 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -499,6 +499,20 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); } + + if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 + const std::string dla_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLAEnable); + if (!dla_enable_env.empty()) { + dla_enable_ = (std::stoi(dla_enable_env) == 0 ? false : true); + } + + if (dla_enable_) { + const std::string dla_core_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDLACore); + if (!dla_core_env.empty()) { + dla_core_ = std::stoi(dla_core_env); + } + } + } } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -948,13 +962,9 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& for (int i = 0; i < static_cast(cycles.size()); ++i) { auto loc = index_to_node_map.find(cycles[i]); if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) { - std::size_t found = loc->second.rfind("_"); - if (found != std::string::npos) { - int trt_node_index = std::stoi(loc->second.substr(found + 1)); - supported_nodes_vector.erase(supported_nodes_vector.begin() + trt_node_index); - trt_cycle = true; - break; - } + supported_nodes_vector.erase(supported_nodes_vector.begin() + cycles[i]); + trt_cycle = true; + break; } } } @@ -1145,6 +1155,27 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; } + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) {//DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead."; + dla_core_ = 0; + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } + // Build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will // be built at runtime tensorrt_ptr::unique_pointer trt_engine; @@ -1261,8 +1292,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse *p = {context->allocate_func, context->release_func, context->allocator_handle, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &int8_enable_, &max_workspace_size_, - trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr, + input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, dla_enable_, + dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_}; *state = p.release(); return 0; @@ -1504,7 +1535,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse trt_config->addOptimizationProfile(*trt_profile); // Set INT8 Per Tensor Dynamic range - if (*(trt_state->int8_enable_ptr) && trt_builder->platformHasFastInt8()) { + if (trt_state->int8_enable && trt_builder->platformHasFastInt8()) { trt_config->setInt8Calibrator(nullptr); if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); @@ -1512,14 +1543,22 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse } // Set precision - if (*(trt_state->fp16_enable_ptr) && *(trt_state->int8_enable_ptr)) { + if (trt_state->fp16_enable && trt_state->int8_enable) { trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (*(trt_state->fp16_enable_ptr)) { + } else if (trt_state->fp16_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (*(trt_state->int8_enable_ptr)) { + } else if (trt_state->int8_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); } + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + // Build engine { auto lock = GetEngineBuildLock(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 7dc121378c..2ac979b358 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -25,6 +25,8 @@ static const std::string kForceSequentialEngineBuild= "ORT_TENSORRT_FORCE_SEQUEN static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; +static const std::string kDLAEnable = "ORT_TENSORRT_DLA_ENABLE"; +static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE"; } // namespace tensorrt_env_vars class TensorrtLogger : public nvinfer1::ILogger { @@ -95,14 +97,15 @@ struct TensorrtFuncState { std::vector> output_info; std::unordered_map>> input_shape_ranges; OrtMutex* tensorrt_mu_ptr = nullptr; - bool* fp16_enable_ptr = nullptr; - bool* int8_enable_ptr = nullptr; + bool fp16_enable; + bool int8_enable; + bool dla_enable; + int dla_core; size_t* max_workspace_size_ptr = nullptr; std::string trt_node_name_with_precision; bool engine_cache_enable; std::string engine_cache_path; nvinfer1::IRuntime* runtime = nullptr; - nvinfer1::IOptimizationProfile* trt_profile = nullptr; AllocatorPtr scratch_allocator; std::unordered_map dynamic_range_map; @@ -146,6 +149,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { size_t max_workspace_size_ = 1 << 30; // 1GB bool fp16_enable_ = false; bool int8_enable_ = false; + bool dla_enable_ = false; + int dla_core_ = 0; bool force_sequential_engine_build_ = false; std::string int8_calibration_cache_name_ = "INT8_calibration_table"; bool int8_use_native_tensorrt_calibration_table_ = false; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a5508b3ae3..d1f2270565 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -578,7 +578,13 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { for (auto option : it->second) { - if (option.first == "has_trt_options") { + if (option.first == "device_id") { + if (!option.second.empty()) { + params.device_id = std::stoi(option.second); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n"); + } + } else if (option.first == "has_trt_options") { if (option.second == "True" || option.second == "true") { params.has_trt_options = true; } else if (option.second == "False" || option.second == "false") {