From daa22f919feb0280e2951e5345483c3c8a8afebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Thu, 11 Jan 2024 22:37:10 +0100 Subject: [PATCH] [TensorRT] query GPU properties only once when setting device_id (#19092) ### Description Querying the GPU properties repeatedly adds overhead: for most models it is not significant, but for very small models it has a significant impact. This change queries the properties only once, when the device is set in the constructor, and caches the resulting compute capability string in `compute_capability_`. The attached screenshot shows the impact for a model with only 2 FC layers: ![image](https://github.com/microsoft/onnxruntime/assets/44298237/b4fdf8bf-0422-43ab-a49e-7d2996cba68e) --- .../tensorrt/tensorrt_execution_provider.cc | 23 ++++++++----------- .../tensorrt/tensorrt_execution_provider.h | 1 + .../tensorrt_execution_provider_utils.h | 4 ++-- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7397b84373..4ece068b50 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1315,6 +1315,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv InitProviderOrtApi(); CUDA_CALL_THROW(cudaSetDevice(device_id_)); + cudaDeviceProp prop; + CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); + compute_capability_ = GetComputeCapacity(prop); if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); @@ -2778,19 +2781,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine_cache_path, trt_state->trt_node_name_with_precision); - const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; + const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string 
profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; + const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile"; std::string timing_cache_path = ""; if (timing_cache_enable_) { - timing_cache_path = GetTimingCachePath(global_cache_path_, prop); + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); } // Load serialized engine diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 7eefdd3cba..bacdf0f3c9 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -258,6 +258,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unique_ptr runtime_ = nullptr; OrtMutex tensorrt_mu_; int device_id_; + std::string compute_capability_; bool context_memory_sharing_enable_ = false; bool layer_norm_fp32_fallback_ = false; size_t max_ctx_mem_size_ = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index 6bbeab7e94..c69299d0ec 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -456,10 +456,10 @@ std::string GetComputeCapacity(const cudaDeviceProp& prop) { * Get Timing by compute capability * */ -std::string GetTimingCachePath(const std::string& root, cudaDeviceProp prop) { +std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" + - GetComputeCapacity(prop) + ".timing"; + compute_cap + ".timing"; return GetCachePath(root, timing_cache_name); }