diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3ca0935b9e..be924d6a68 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -169,11 +169,20 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSource source{}; t = toUpper(t); if (t == "CUBLAS") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUBLAS; +#endif } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif } else if (t == "CUDNN") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUDNN; +#endif } else if (t == "EDGE_MASK_CONVOLUTIONS") { source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; } else if (t == "JIT_CONVOLUTIONS") { @@ -298,6 +307,25 @@ void CudaCall(cudnnStatus_t retCode, const char* exprString return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -314,6 +342,7 @@ void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*curr // if cudaMalloc fails, returns nullptr. return outputPtr; } +#endif void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); @@ -3152,14 +3181,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } - -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); @@ -3606,14 +3631,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (context_update) { if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 *(trt_state->context) = std::unique_ptr( trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { *(trt_state->context) = std::unique_ptr( @@ -3830,13 +3853,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f4dae57487..ec14057956 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -116,8 +116,11 @@ using unique_pointer = std::unique_ptr; // class OutputAllocator : public nvinfer1::IOutputAllocator { public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; - +#endif void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; void* getBuffer() {