mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
[TensorRT] query GPU properties only once when setting device_id (#19092)
### Description For most models, repeatedly querying the GPU properties does not add significant overhead, but for very small models the impact is significant. The attached screenshot shows the impact when using only 2 FC layers.
This commit is contained in:
parent
4d1243b4b4
commit
daa22f919f
3 changed files with 12 additions and 16 deletions
|
|
@ -1315,6 +1315,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
InitProviderOrtApi();
|
||||
|
||||
CUDA_CALL_THROW(cudaSetDevice(device_id_));
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
|
||||
compute_capability_ = GetComputeCapacity(prop);
|
||||
if (info.has_user_compute_stream) {
|
||||
external_stream_ = true;
|
||||
stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
|
||||
|
|
@ -2778,19 +2781,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
|
||||
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
|
||||
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
|
||||
std::string compute_capability = GetComputeCapacity(prop);
|
||||
|
||||
if (!has_dynamic_shape) {
|
||||
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile";
|
||||
std::string timing_cache_path = "";
|
||||
bool engine_update = false;
|
||||
if (timing_cache_enable_) {
|
||||
timing_cache_path = GetTimingCachePath(global_cache_path_, prop);
|
||||
timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
|
||||
}
|
||||
{
|
||||
// ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading.
|
||||
|
|
@ -3043,18 +3042,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
|
||||
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
|
||||
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
|
||||
std::string compute_capability = GetComputeCapacity(prop);
|
||||
|
||||
// Prepare cache name
|
||||
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile";
|
||||
std::string timing_cache_path = "";
|
||||
if (timing_cache_enable_) {
|
||||
timing_cache_path = GetTimingCachePath(global_cache_path_, prop);
|
||||
timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
|
||||
}
|
||||
|
||||
// Load serialized engine
|
||||
|
|
|
|||
|
|
@ -258,6 +258,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
std::unique_ptr<nvinfer1::IRuntime> runtime_ = nullptr;
|
||||
OrtMutex tensorrt_mu_;
|
||||
int device_id_;
|
||||
std::string compute_capability_;
|
||||
bool context_memory_sharing_enable_ = false;
|
||||
bool layer_norm_fp32_fallback_ = false;
|
||||
size_t max_ctx_mem_size_ = 0;
|
||||
|
|
|
|||
|
|
@ -456,10 +456,10 @@ std::string GetComputeCapacity(const cudaDeviceProp& prop) {
|
|||
* Get Timing by compute capability
|
||||
*
|
||||
*/
|
||||
std::string GetTimingCachePath(const std::string& root, cudaDeviceProp prop) {
|
||||
std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) {
|
||||
// append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache
|
||||
const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" +
|
||||
GetComputeCapacity(prop) + ".timing";
|
||||
compute_cap + ".timing";
|
||||
return GetCachePath(root, timing_cache_name);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue