From 0ebe2fab51e8ec35c4682deb2adb6ba88f2fa618 Mon Sep 17 00:00:00 2001 From: stevenlix <38092805+stevenlix@users.noreply.github.com> Date: Wed, 15 Jul 2020 02:35:42 -0700 Subject: [PATCH] Refactor TensorRT EP code to better handle dynamic shape subgraphs (#4504) * build engine in runtime for dynamic shape subgraphs * Update TensorRT-ExecutionProvider.md * Update TensorRT-ExecutionProvider.md * fix build issue * Add more instructions on how to use engine caching * add precision to trt node name * Update tensorrt_execution_provider.cc * Update tensorrt_execution_provider.cc --- cmake/CMakeLists.txt | 4 + cmake/onnxruntime_providers.cmake | 1 + .../TensorRT-ExecutionProvider.md | 14 +- .../tensorrt/tensorrt_execution_provider.cc | 749 ++++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 47 +- 5 files changed, 557 insertions(+), 258 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index af70328fb2..115bd52201 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -914,6 +914,10 @@ if (onnxruntime_USE_TENSORRT) set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:nvinfer.dll /DELAYLOAD:nvinfer_plugin.dll") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations") + # needs to link with stdc++fs in Linux + if (NOT APPLE) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES stdc++fs) + endif() endif() endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index ea3c9bbeeb..a9e246dbca 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -329,6 +329,7 @@ if (onnxruntime_USE_TENSORRT) include_directories(${ONNXRUNTIME_ROOT}/../cmake/external/onnx) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) if (WIN32) + add_definitions(-D_SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING=1) set(OLD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4996 /wd4244 /wd4267 /wd4099 /wd4551 /wd4505 /wd4515 /wd4706 /wd4456 /wd4324 /wd4701 /wd4804 /wd4702") if (CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index 3a7d78db8e..54bbced09e 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -67,9 +67,13 @@ ORT_TENSORRT_MIN_SUBGRAPH_SIZE: minimum node size in a subgraph after partitioni ORT_TENSORRT_FP16_ENABLE: Enable FP16 mode in TensorRT -By default TensorRT execution provider builds an ICudaEngine with max workspace size = 1 GB, max partition iterations = 1000, min subgraph size = 1 and FP16 mode is disabled. +ORT_TENSORRT_ENGINE_CACHE_ENABLE: Enable TensorRT engine caching. The purpose of using engine caching is to save engine build time in the cases that TensorRT may take long time to optimize and build engine. Engine will be cached after it's built at the first time so that next time when inference session is created the engine can be loaded directly from cache. Note each engine is created for specific settings such as precision (FP32/FP16/INT8 etc), workspace, profiles etc, and specific GPUs and it's not portable, so it's essential to make sure those settings are not changing, otherwise the engines need to be rebuilt and cached again. Also please clean up any old engine cache files (.engine) before enabling the feature for new models. Right now engine caching is only available for static shape models (subgraphs). For dynamic shape cases, since the engine is dynamically created at run-time it's hard to reuse it from previous run without knowing the profile the engine was created from. Dyanmic shape engine caching will be addressed in the future. -One can override these defaults by setting environment variables ORT_TENSORRT_MAX_WORKSPACE_SIZE, ORT_TENSORRT_MAX_PARTITION_ITERATIONS, ORT_TENSORRT_MIN_SUBGRAPH_SIZE and ORT_TENSORRT_FP16_ENABLE. +ORT_TENSORRT_ENGINE_CACHE_PATH: Specify path for TensorRT engine files if ORT_TENSORRT_ENGINE_CACHE_ENABLE is 1 + +By default TensorRT execution provider builds an ICudaEngine with max workspace size = 1 GB, max partition iterations = 1000, min subgraph size = 1, FP16 mode is disabled and TensorRT engine caching is disabled. + +One can override these defaults by setting environment variables ORT_TENSORRT_MAX_WORKSPACE_SIZE, ORT_TENSORRT_MAX_PARTITION_ITERATIONS, ORT_TENSORRT_MIN_SUBGRAPH_SIZE, ORT_TENSORRT_FP16_ENABLE, ORT_TENSORRT_ENGINE_CACHE_ENABLE and ORT_TENSORRT_ENGINE_CACHE_PATH. e.g. on Linux ### override default max workspace size to 2GB @@ -83,3 +87,9 @@ export ORT_TENSORRT_MIN_SUBGRAPH_SIZE=5 ### Enable FP16 mode in TensorRT export ORT_TENSORRT_FP16_ENABLE=1 + +### Enable TensorRT engine caching +export ORT_TENSORRT_ENGINE_CACHE_ENABLE=1 + +### Specify TensorRT engine cache path +export ORT_TENSORRT_ENGINE_CACHE_PATH="cache" diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 4321738bf2..7e7785cb55 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -3,7 +3,6 @@ #include #include "core/graph/onnx_protobuf.h" - #include "tensorrt_execution_provider.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" @@ -24,14 +23,26 @@ #include "gsl/gsl" #include "core/graph/model.h" #include "core/providers/cuda/gpu_data_transfer.h" +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; +namespace fs = std::experimental::filesystem; namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); Status st; }; + +std::string GetEnginePath(const ::std::string& root, const std::string& name) { + if (root.empty()) { + return name + ".engine"; + } else { + fs::path path = root; + path.append(name + ".engine"); + return path.string(); + } +} } // namespace namespace onnxruntime { @@ -140,6 +151,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (!dump_subgraphs_env.empty()) { dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true); } + + const std::string engine_cache_enable_env = env_instance.GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable); + if (!engine_cache_enable_env.empty()) { + engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true); + } + + if (engine_cache_enable_) { + engine_cache_path_ = env_instance.GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); + if (!engine_cache_path_.empty() && !fs::is_directory(engine_cache_path_)) { + if (!fs::create_directory(engine_cache_path_)) { + throw std::runtime_error("Failed to create directory " + engine_cache_path_); + } + } + runtime_ = nvinfer1::createInferRuntime(GetTensorrtLogger()); + } } TensorrtExecutionProvider::~TensorrtExecutionProvider() {} @@ -204,7 +230,7 @@ bool FindCycleHelper(int i, const std::list* adjacency_map, return false; } -// Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape +// Remove nodes with empty shape (for example [1, 0]) because TensorRT 7.0 doens't support empty shape SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) { // Here only NonZero, NonMaxSuppression and TopK related empty shape nodes are removed, particularly for RCNN models. // TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around @@ -287,6 +313,17 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } } + for (const auto& input : node->ImplicitInputDefs()) { + const auto& it = fused_outputs.find(input); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); + erased.insert(input); + } else if (erased.find(input) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + fused_inputs[input] = input_order++; + } + } + // For output searching, there are two special cases, // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, // if the output is connected to nodes that don't belong to the subgraph, the output need to be added @@ -470,7 +507,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto.SerializeToOstream(&dump); } - + // Get supported node list recursively SubGraphCollection_t parser_nodes_list; TensorrtLogger& trt_logger = GetTensorrtLogger(); @@ -612,7 +649,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Remove nodes with empty shape SubGraphCollection_t parser_nodes_vector = RemoveEmptyShapeNodes(graph); - // Get supported node list by TensorRT parser + // Get supported node list from TensorRT parser SubGraphCollection_t supported_nodes_vector; bool early_termination = false; supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); @@ -633,26 +670,28 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Construct subgraph capability from node list std::vector> result; - int counter = 0; + int counter = 0, number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); result.push_back(onnxruntime::make_unique(std::move(sub_graph))); + number_of_trt_nodes += group.first.size(); } } + const int number_of_subgraphs = supported_nodes_vector.size(); + if (number_of_subgraphs == 0) { + LOGS_DEFAULT(WARNING) << "No graph is running on TensorRT exeuction provider."; + } else { + LOGS_DEFAULT(INFO) << "Number of subgraphs running on TensorRT exeuction provider: " << number_of_subgraphs; + } + return result; } common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { for (const auto* fused_node : fused_nodes) { - std::vector input_indexes; - std::vector output_indexes; - std::unordered_map>> input_shape_ranges; - std::vector> output_shapes; - std::vector output_types; - // Build map from input name to its index in input definitions std::unordered_map input_map; const auto& input_defs = fused_node->InputDefs(); @@ -700,113 +739,102 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorparse(string_buf.data(), string_buf.size()); trt_config->setMaxWorkspaceSize(max_workspace_size_); - // Set optimization profile for dynamic shapes - auto trt_profile = trt_builder->createOptimizationProfile(); - for (unsigned int i = 0, end = trt_network->getNbInputs(); i < end; ++i) { + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map>> input_shape_ranges; + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + // Initialize shape range for dynamic shape tensors + bool has_dynamic_shape = false; + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); nvinfer1::Dims dims = input->getDimensions(); - nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); - int nb_dims = dims.nbDims; - if (input->isShapeTensor()) { // Shape tensor - std::vector shapes_min(nb_dims), shapes_opt(nb_dims), shapes_max(nb_dims); - for (int j = 0, end = nb_dims; j < end; ++j) { - shapes_min[j] = 1; - shapes_opt[j] = 1; - shapes_max[j] = 1000; - } - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], nb_dims); - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims); - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims); - } else { // Execution tensor - bool is_dynamic_shape = false; - for (int j = 0, end = nb_dims; j < end; ++j) { - // For dynamic shape subgraph, a dummy engine is created at compile phase. - // Real engine will be created at compute phase based on input data - if (dims.d[j] == -1) { // Dynamic shape - dims_min.d[j] = 1; - dims_opt.d[j] = 1; - dims_max.d[j] = 1; - is_dynamic_shape = true; - } - } - if (is_dynamic_shape) { - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + if (input->isShapeTensor()) { + // Shape tensor + input_shape_ranges[input_name][0] = std::make_pair(INT_MAX, INT_MIN); + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dims.d[j] == -1) { + input_shape_ranges[input_name][j] = std::make_pair(INT_MAX, INT_MIN); + has_dynamic_shape = true; + } } } } - trt_config->addOptimizationProfile(trt_profile); + std::string trt_node_name_with_precision = fused_node->Name(); if (fp16_enable_ && trt_builder->platformHasFastFp16()) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled."; + } + + // Build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will + // be built at runtime + tensorrt_ptr::unique_pointer trt_engine; + tensorrt_ptr::unique_pointer trt_context; + if (!has_dynamic_shape) { + std::ifstream planFile(GetEnginePath(engine_cache_path_, trt_node_name_with_precision), std::ios::binary | std::ios::in); + if (planFile && engine_cache_enable_) { + planFile.seekg(0, std::ios::end); + int engine_size = planFile.tellg(); + planFile.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + planFile.read((char*)engine_buf.get(), engine_size); + planFile.close(); + trt_engine = tensorrt_ptr::unique_pointer(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + GetEnginePath(engine_cache_path_, trt_node_name_with_precision); + } else { + trt_engine = tensorrt_ptr::unique_pointer(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build Engine for fused node: " + fused_node->Name()); + } + + if (engine_cache_enable_) { + nvinfer1::IHostMemory* serializedModel = trt_engine->serialize(); + std::ofstream file(GetEnginePath(engine_cache_path_, trt_node_name_with_precision), std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serializedModel->data()), serializedModel->size()); + serializedModel->destroy(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + GetEnginePath(engine_cache_path_, trt_node_name_with_precision); + } + } + trt_context = tensorrt_ptr::unique_pointer(trt_engine->createExecutionContext()); + if (trt_context == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build Execution Context for fused node: " + fused_node->Name()); + } } - auto trt_engine = tensorrt_ptr::unique_pointer(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); - if (trt_engine == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build Engine for fused node: " + fused_node->Name()); - } - - // Build TensorRT context - auto trt_context = tensorrt_ptr::unique_pointer(trt_engine->createExecutionContext()); - if (trt_context == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not build Execution Context for fused node: " + fused_node->Name()); - } - - // Get input shape and binding index - int num_inputs = trt_network->getNbInputs(); - input_indexes.resize(num_inputs); + // Create input to index map for (int i = 0; i < num_inputs; ++i) { auto input = trt_network->getInput(i); - const std::string& name = input->getName(); - size_t bindingIndex = trt_engine->getBindingIndex(name.c_str()); - nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(bindingIndex)); - auto iter = input_map.find(name); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); if (iter != input_map.end()) { - input_indexes[bindingIndex] = iter->second; - } - if (input->isShapeTensor()) { // Shape tensor - for (int j = 0, end = dimensions.nbDims; j < end; ++j) { - input_shape_ranges[bindingIndex][j] = std::make_pair(INT_MAX, INT_MIN); - } - } else { - for (int j = 0, end = dimensions.nbDims; j < end; ++j) { - if (dimensions.d[j] == -1) { - input_shape_ranges[bindingIndex][j] = std::make_pair(INT_MAX, INT_MIN); - } - } + input_indexes[input_name] = iter->second; } } - // Get output shape and binding index - int num_outputs = trt_network->getNbOutputs(); - output_indexes.resize(num_outputs); - output_shapes.resize(num_outputs); - output_types.resize(num_outputs); + // Create output to index and type maps + const auto& graph_output = model_proto.graph().output(); for (int i = 0; i < num_outputs; ++i) { - const std::string& name = trt_network->getOutput(i)->getName(); - size_t bindingIndex = trt_engine->getBindingIndex(name.c_str()); - nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(bindingIndex)); - bindingIndex -= num_inputs; - auto iter = output_map.find(name); + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); if (iter != output_map.end()) { - output_indexes[bindingIndex] = iter->second; - } - for (int j = 0, end = dimensions.nbDims; j < end; ++j) { - output_shapes[bindingIndex].push_back(dimensions.d[j]); - } - - const auto& graph_output = model_proto.graph().output(); + output_indexes[output_name] = iter->second; + } const auto& tensor_type = graph_output[i].type().tensor_type(); - output_types[bindingIndex] = tensor_type.elem_type(); + output_types[output_name] = tensor_type.elem_type(); } - ORT_ENFORCE(trt_engine->getNbBindings() == (num_inputs + num_outputs)); - // Save engine, context and input/output info to map parsers_.emplace(fused_node->Name(), std::move(trt_parser)); engines_.emplace(fused_node->Name(), std::move(trt_engine)); @@ -817,7 +845,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorName()].push_back(output_indexes); output_info_[fused_node->Name()].push_back(output_types); input_shape_ranges_[fused_node->Name()] = input_shape_ranges; - output_shapes_[fused_node->Name()] = output_shapes; // Create function state // TODO: remove default capture @@ -827,7 +854,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(), &engines_[context->node_name], &contexts_[context->node_name], builders_[context->node_name].get(), networks_[context->node_name].get(), input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_, &fp16_enable_, + input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &max_workspace_size_}; *state = p.release(); return 0; @@ -844,85 +871,167 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); - const std::vector& input_indexes = (trt_state->input_info)[0]; - const std::vector& output_indexes = (trt_state->output_info)[0]; - const std::vector& output_types = (trt_state->output_info)[1]; - - int num_binding_inputs = input_indexes.size(); - int num_binding_outputs = output_indexes.size(); - int total_bindings = num_binding_inputs + num_binding_outputs; - std::vector buffers(total_bindings); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto& shape_ranges = trt_state->input_shape_ranges; + auto trt_builder = trt_state->builder; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + int num_inputs = input_indexes.size(); + int num_outputs = output_indexes.size(); // Update shape ranges - bool dimension_update = false; - auto trt_context = trt_state->context->get(); - auto trt_builder = trt_state->builder; + bool engine_update = false; + std::unordered_map dimension_update; + std::unordered_map> tensor_shape_values; nvinfer1::IOptimizationProfile* trt_profile = nullptr; - for (int i = 0, end = num_binding_inputs; i < end; ++i) { + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + // Check and update shape ranges for dynamic shape inputs - auto& shape_ranges = trt_state->input_shape_ranges; - if (shape_ranges.find(i) != shape_ranges.end()) { - // TODO: check if getInput indexing is same with binding index - auto input = trt_state->network->getInput(i); - nvinfer1::Dims dims = input->getDimensions(); - nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); + dimension_update[input_name] = false; + if (shape_ranges.find(input_name) != shape_ranges.end()) { + int input_index = 0; + const auto& iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } - const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]); + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index); auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); - const auto& tensor_shape = ort.GetTensorShape(tensor_info); - auto& engine = trt_context->getEngine(); - nvinfer1::Dims dimensions = engine.getBindingDimensions(static_cast(i)); - int nb_dims = dimensions.nbDims; - for (int j = 0, end = nb_dims; j < end; ++j) { - auto& shape_range = shape_ranges[i]; - if (shape_range.find(j) != shape_range.end()) { - dims_min.d[j] = shape_range[j].first; - dims_opt.d[j] = shape_range[j].second; - dims_max.d[j] = shape_range[j].second; + const auto& tensor_shapes = ort.GetTensorShape(tensor_info); + auto& shape_range = shape_ranges[input_name]; - // Update minimum dimension - if (tensor_shape[j] < shape_range[j].first) { - shape_range[j].first = tensor_shape[j]; - dims_min.d[j] = tensor_shape[j]; - dims_opt.d[j] = tensor_shape[j]; - dimension_update = true; + // Create shape profile + if (input->isShapeTensor()) { + // Get shape values for shape tensor input + const auto& tensor_type = ort.GetTensorElementType(tensor_info); + int shape_size = nb_dims == 0 ? 1 : tensor_shapes[0]; + tensor_shape_values[input_name].reserve(shape_size); + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + int32_t* input = new int32_t[shape_size]; + cudaMemcpy(input, ort.GetTensorData(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost); + for (int j = 0; j < shape_size; ++j) { + tensor_shape_values[input_name][j] = input[j]; + } + delete[] input; + break; } - // Update maximum dimension - if (tensor_shape[j] > shape_range[j].second) { - shape_range[j].second = tensor_shape[j]; - dims_max.d[j] = tensor_shape[j]; - dims_opt.d[j] = tensor_shape[j]; - dimension_update = true; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + int64_t* input = new int64_t[shape_size]; + cudaMemcpy(input, ort.GetTensorData(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost); + for (int j = 0; j < shape_size; ++j) { + tensor_shape_values[input_name][j] = static_cast(input[j]); + } + delete[] input; + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported."); } } - } - if (dimension_update) { - if (trt_profile == nullptr) { - trt_profile = trt_builder->createOptimizationProfile(); - } - if (engine.isShapeBinding(i)) { - std::vector shapes_min(nb_dims), shapes_opt(nb_dims), shapes_max(nb_dims); - for (int j = 0, end = nb_dims; j < end; ++j) { - shapes_min[j] = dims_min.d[j]; - shapes_opt[j] = dims_opt.d[j]; - shapes_max[j] = dims_max.d[j]; + // Update shape ranges + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + int shape_range_size = shape_range.size(); + if (shape_size == shape_range_size) { + // If shape size matches, check/update shape range + for (int j = 0; j < shape_size; ++j) { + shapes_min[j] = shape_range[j].first; + shapes_opt[j] = shape_range[j].second; + shapes_max[j] = shape_range[j].second; + + const auto& tensor_shape_value = tensor_shape_values[input_name][j]; + // Update shape range lower bound + if (tensor_shape_value < shape_range[j].first) { + shape_range[j].first = tensor_shape_value; + shapes_min[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + dimension_update[input_name] = true; + } + // Update shape range upper bound + if (tensor_shape_value > shape_range[j].second) { + shape_range[j].second = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + dimension_update[input_name] = true; + } } - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], nb_dims); - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims); - trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims); } else { - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + // If shape size doesn't match, initialize shape_range with the new shape value + shape_range.clear(); + for (int j = 0; j < shape_size; ++j) { + const auto& tensor_shape_value = tensor_shape_values[input_name][j]; + shape_range[j] = std::make_pair(tensor_shape_value, tensor_shape_value); + shapes_min[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + } + dimension_update[input_name] = true; + } + + if (dimension_update[input_name]) { + if (trt_profile == nullptr) { + trt_profile = trt_builder->createOptimizationProfile(); + } + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + } + + } else //execution tensor + { + nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + const auto& tensor_shape = tensor_shapes[j]; + if (shape_range.find(j) != shape_range.end()) { + dims_min.d[j] = shape_range[j].first; + dims_opt.d[j] = shape_range[j].second; + dims_max.d[j] = shape_range[j].second; + + // Update minimum dimension + if (tensor_shape < shape_range[j].first) { + shape_range[j].first = tensor_shape; + dims_min.d[j] = tensor_shape; + dims_opt.d[j] = tensor_shape; + dimension_update[input_name] = true; + } + // Update maximum dimension + if (tensor_shape > shape_range[j].second) { + shape_range[j].second = tensor_shape; + dims_max.d[j] = tensor_shape; + dims_opt.d[j] = tensor_shape; + dimension_update[input_name] = true; + } + } + } + + if (dimension_update[input_name]) { + if (trt_profile == nullptr) { + trt_profile = trt_builder->createOptimizationProfile(); + } + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); } } + ort.ReleaseTensorTypeAndShapeInfo(tensor_info); + } + + if (!engine_update && dimension_update[input_name]) { + engine_update = true; } } - // Regenerate engine and context + // Regenerate engine // Only one profile is generated, so no need to explicitly set optimization profile - if (dimension_update) { + if (engine_update) { auto trt_config = tensorrt_ptr::unique_pointer(trt_builder->createBuilderConfig()); trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); trt_config->addOptimizationProfile(trt_profile); @@ -931,111 +1040,285 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext->reset(); trt_state->engine->reset(); - *(trt_state->engine) = tensorrt_ptr::unique_pointer( - trt_builder->buildEngineWithConfig(*trt_state->network, *trt_config)); + *(trt_state->engine) = tensorrt_ptr::unique_pointer( + trt_builder->buildEngineWithConfig(*trt_state->network, *trt_config)); if (trt_state->engine->get() == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } + trt_engine = trt_state->engine->get(); + *(trt_state->context) = tensorrt_ptr::unique_pointer( - trt_state->engine->get()->createExecutionContext()); + trt_state->engine->get()->createExecutionContext()); if (trt_state->context->get() == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Create Context."); } trt_context = trt_state->context->get(); } - // Set input shapes and assign input buffers - for (int i = 0, end = num_binding_inputs; i < end; ++i) { - const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]); - auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); - const auto& tensor_shape = ort.GetTensorShape(tensor_info); - - // Set dynamic shapes - nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(i)); - int nb_dims = dimensions.nbDims; - if (dimension_update) { - for (int j = 0, end = nb_dims; j < end; ++j) - dimensions.d[j] = tensor_shape[j]; - trt_context->setBindingDimensions(i, dimensions); - } - - auto tensor_type = ort.GetTensorElementType(tensor_info); - ort.ReleaseTensorTypeAndShapeInfo(tensor_info); - if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - buffers[i] = const_cast(ort.GetTensorData(input_tensor)); - } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { - buffers[i] = const_cast(ort.GetTensorData(input_tensor)); - } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) { - buffers[i] = const_cast(ort.GetTensorData(input_tensor)); - } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { - buffers[i] = const_cast(ort.GetTensorData(input_tensor)); - } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { - buffers[i] = const_cast(ort.GetTensorData(input_tensor)); - } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - input_dim_size *= tensor_shape[j]; - } - CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i], input_dim_size * sizeof(int32_t))); - cuda::Impl_Cast(ort.GetTensorData(input_tensor), reinterpret_cast(buffers[i]), input_dim_size); + // Get input and output binding names + int total_bindings = trt_engine->getNbBindings(); + std::vector buffers(total_bindings); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + if (trt_engine->bindingIsInput(i)) { + input_binding_names.push_back(trt_engine->getBindingName(i)); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); + output_binding_names.push_back(trt_engine->getBindingName(i)); } } - // Set output shapes and assign output buffers - std::vector output_dim_sizes(num_binding_outputs, 1); - std::vector output_tensor(num_binding_outputs, nullptr); - for (int i = 0, end = num_binding_outputs; i < end; ++i) { - // Set dynamic shapes - nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(i + num_binding_inputs)); - int nb_dims = dimensions.nbDims; - for (int j = 0, end = nb_dims; j < end; ++j) { - trt_state->output_shapes[i][j] = dimensions.d[j]; + // Set input shapes and assign input buffers + std::vector binding_buffers_to_freeup; + for (int i = 0, end = input_binding_names.size(); i < end; ++i) { + const std::string& input_name = input_binding_names[i]; + int binding_index = trt_engine->getBindingIndex(input_name.c_str()); + if (binding_index == -1) { + continue; } - int output_index = output_indexes[i]; - output_tensor[i] = ort.KernelContext_GetOutput(context, output_index, trt_state->output_shapes[i].data(), trt_state->output_shapes[i].size()); + int input_index = 0; + const auto& iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index); + auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); + const auto& tensor_shapes = ort.GetTensorShape(tensor_info); - if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); - } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { - buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); - } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) { - buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); - } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { - buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); - } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { - buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); - } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - SafeInt output_dim_size(output_dim_sizes[i]); - for (int j = 0, end = nb_dims; j < end; ++j) { - output_dim_size *= dimensions.d[j]; + // Set dynamic shapes + nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(binding_index)); + int nb_dims = dimensions.nbDims; + if (dimension_update.find(input_name) != dimension_update.end()) { + if (trt_engine->isShapeBinding(binding_index)) { + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + } else { + for (int j = 0, end = nb_dims; j < end; ++j) { + dimensions.d[j] = tensor_shapes[j]; + } + trt_context->setBindingDimensions(binding_index, dimensions); + } + } + + const auto& input_type = ort.GetTensorElementType(tensor_info); + switch (input_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(float))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(MLFloat16))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(bool))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int8_t))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int32_t))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = const_cast(input_tensor_ptr); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 + auto input_tensor_ptr = ort.GetTensorData(input_tensor); + if (input_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int32_t))); + } else { + SafeInt input_dim_size = 1; + for (int j = 0, end = nb_dims; j < end; ++j) { + if (tensor_shapes[j] == 0) { + input_dim_size = 1; + break; + } else { + input_dim_size *= tensor_shapes[j]; + } + } + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], input_dim_size * sizeof(int32_t))); + cuda::Impl_Cast(input_tensor_ptr, reinterpret_cast(buffers[binding_index]), input_dim_size); + } + binding_buffers_to_freeup.push_back(binding_index); + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP input onnx tensor data type: " + std::to_string(input_type) + " not supported."); + } + } + ort.ReleaseTensorTypeAndShapeInfo(tensor_info); + } + + // Set output shapes and assign output buffers + std::vector output_dim_sizes(num_outputs, 1); + std::vector output_tensor(num_outputs, nullptr); + for (int i = 0, end = output_binding_names.size(); i < end; ++i) { + // Set dynamic shapes + const std::string& output_name = output_binding_names[i]; + int binding_index = trt_engine->getBindingIndex(output_name.c_str()); + if (binding_index == -1) { + continue; + } + + int output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); + int nb_dims = dimensions.nbDims; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + output_shapes[j] = dimensions.d[j]; + } + output_tensor[i] = ort.KernelContext_GetOutput(context, output_index, output_shapes.data(), output_shapes.size()); + + int output_type = 0; + const auto& type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + switch (output_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(float))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(MLFloat16))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(bool))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int8_t))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int32_t))); + binding_buffers_to_freeup.push_back(binding_index); + } else { + buffers[binding_index] = output_tensor_ptr; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], sizeof(int32_t))); + output_dim_sizes[i] = 1; + } else { + SafeInt output_dim_size(output_dim_sizes[i]); + for (int j = 0, end = nb_dims; j < end; ++j) { + if (dimensions.d[j] == 0) { + output_dim_size = 1; + break; + } else { + output_dim_size *= dimensions.d[j]; + } + } + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[binding_index], output_dim_size * sizeof(int32_t))); + output_dim_sizes[i] = output_dim_size; + } + binding_buffers_to_freeup.push_back(binding_index); + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } - CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i + num_binding_inputs], output_dim_size * sizeof(int32_t))); - output_dim_sizes[i] = output_dim_size; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP output onnx tensor data type: " + std::to_string(output_types[i]) + " not supported."); } } // Run TRT inference if (!trt_context->enqueueV2(&buffers[0], nullptr, nullptr)) { + for (const auto& binding_index : binding_buffers_to_freeup) { + cudaFree(buffers[binding_index]); + } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP Execution Context Enqueue Failed."); } // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - for (int i = 0, end = num_binding_outputs; i < end; ++i) { - if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - cuda::Impl_Cast(reinterpret_cast(buffers[i + num_binding_inputs]), ort.GetTensorMutableData(output_tensor[i]), output_dim_sizes[i]); - cudaDeviceSynchronize(); - cudaFree(buffers[i + num_binding_inputs]); + for (int i = 0, end = output_binding_names.size(); i < end; ++i) { + const std::string& output_name = output_binding_names[i]; + size_t binding_index = trt_engine->getBindingIndex(output_name.c_str()); + int output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; } + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]); + } + } + } + + cudaDeviceSynchronize(); + for (const auto& binding_index : binding_buffers_to_freeup) { + cudaFree(buffers[binding_index]); } return Status::OK(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 33d7eb1532..4a0b9d9c30 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -17,6 +17,8 @@ static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS"; +static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"; +static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars class TensorrtLogger : public nvinfer1::ILogger { @@ -40,19 +42,19 @@ class TensorrtLogger : public nvinfer1::ILogger { namespace tensorrt_ptr { - struct TensorrtInferDeleter { - template - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } - }; - +struct TensorrtInferDeleter { template - using unique_pointer = std::unique_ptr; + void operator()(T* obj) const { + if (obj) { + obj->destroy(); + } + } }; +template +using unique_pointer = std::unique_ptr; +}; // namespace tensorrt_ptr + // Information needed to construct trt execution providers. struct TensorrtExecutionProviderInfo { int device_id{0}; @@ -60,19 +62,17 @@ struct TensorrtExecutionProviderInfo { // Information to construct kernel function state. struct TensorrtFuncState { - AllocateFunc test_allocate_func = nullptr; DestroyFunc test_release_func = nullptr; AllocatorHandle allocator = nullptr; nvonnxparser::IParser* parser = nullptr; - tensorrt_ptr::unique_pointer * engine = nullptr; - tensorrt_ptr::unique_pointer * context = nullptr; + tensorrt_ptr::unique_pointer* engine = nullptr; + tensorrt_ptr::unique_pointer* context = nullptr; nvinfer1::IBuilder* builder = nullptr; nvinfer1::INetworkDefinition* network = nullptr; - std::vector> input_info; - std::vector> output_info; - std::unordered_map>> input_shape_ranges; - std::vector> output_shapes; + std::vector> input_info; + std::vector> output_info; + std::unordered_map>> input_shape_ranges; OrtMutex* tensorrt_mu_ptr = nullptr; bool* fp16_enable_ptr = nullptr; size_t* max_workspace_size_ptr = nullptr; @@ -104,7 +104,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { int min_subgraph_size_ = 1; bool fp16_enable_ = false; bool dump_subgraphs_ = false; - + bool engine_cache_enable_ = false; + std::string engine_cache_path_; + nvinfer1::IRuntime* runtime_ = nullptr; OrtMutex tensorrt_mu_; int device_id_; @@ -113,10 +115,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map> contexts_; std::unordered_map> builders_; std::unordered_map> networks_; - std::unordered_map>> input_info_; - std::unordered_map>> output_info_; - std::unordered_map>>> input_shape_ranges_; - std::unordered_map>> output_shapes_; + std::unordered_map>> input_info_; + std::unordered_map>> output_info_; + std::unordered_map>>> input_shape_ranges_; /**Get IndexedSubGraph based on node list of the subgraph*/ std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, @@ -137,4 +138,4 @@ class TensorrtExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file