From 58e6aaa414d05293f68e8480f4bc0c4980e4fa63 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Wed, 6 Nov 2019 13:00:17 -0800 Subject: [PATCH] Fix crash in releasing TLS from CUDA EP dtor (#2329) thread_local/global/static destruction order depends on implementation details of compilers and OS. The bug happens when thread_local is already out of scope while static EP being destructed, thus causing access violation in EP's destructor when accessing thread_local. The fix is to maintain ownership inside EP with a mapping from tid to ThreadLocalContext, to avoid accessing thread_local in EP's destructor. This way, no matter what the destruction order is, no access violation would be triggered. --- .../providers/cuda/cuda_execution_provider.cc | 42 ++++++++++++------- .../providers/cuda/cuda_execution_provider.h | 4 +- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d11f2d60aa..2306fd10fb 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -103,7 +103,6 @@ CUDAExecutionProvider::~CUDAExecutionProvider() { CUDA_CALL_THROW(cudaEventDestroy(e)); it = deferred_release_cpu_ptr_.erase(it); } - ReleasePerThreadStuffs(); } CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadContext() const { @@ -114,28 +113,39 @@ CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadCont auto* p = per_thread_context_map_.get(); if (p->count(this) == 0) { std::lock_guard lock(context_pool_mutex_); - if (context_pool_.empty()) { - p->insert(std::make_pair(this, std::make_shared(device_id_))); + unsigned int tid = logging::GetThreadId(); + auto inuse_iter = inuse_contexts_.find(tid); + std::shared_ptr ptc; + if (inuse_iter == inuse_contexts_.end()) { + if (retired_context_pool_.empty()) { + ptc = std::make_shared(device_id_); + } else { + ptc = retired_context_pool_.back(); + retired_context_pool_.pop_back(); + } } else { - p->insert(std::make_pair(this, context_pool_.back())); - context_pool_.pop_back(); + ptc = inuse_iter->second; } + p->insert(std::make_pair(this, ptc)); } return *(p->at(this)); } void CUDAExecutionProvider::ReleasePerThreadStuffs() const { - if (per_thread_context_map_ != nullptr && !per_thread_context_map_->empty()) { - auto iter_ctx = per_thread_context_map_->find(this); - if (iter_ctx != per_thread_context_map_->end()) { - std::lock_guard lock(context_pool_mutex_); - context_pool_.push_back(iter_ctx->second); - per_thread_context_map_->erase(iter_ctx); - // Release TLS if empty to avoid memory leak report - if (per_thread_context_map_->empty()) { - per_thread_context_map_.reset(nullptr); - } - } + ORT_ENFORCE(per_thread_context_map_ != nullptr); + auto iter_ctx = per_thread_context_map_->find(this); + ORT_ENFORCE(iter_ctx != per_thread_context_map_->end()); + + std::lock_guard lock(context_pool_mutex_); + unsigned int tid = logging::GetThreadId(); + if (inuse_contexts_.count(tid)) { + inuse_contexts_.erase(tid); + } + retired_context_pool_.push_back(iter_ctx->second); + per_thread_context_map_->erase(iter_ctx); + // Release TLS if empty to avoid memory leak report + if (per_thread_context_map_->empty()) { + per_thread_context_map_.reset(nullptr); } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 57a5b3957b..4d1f051f9a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -138,9 +138,11 @@ class CUDAExecutionProvider : public IExecutionProvider { // thread local context during execution using PerThreadContextMap = std::unordered_map>; static thread_local std::unique_ptr per_thread_context_map_; + // in some compilers, TLS may not be accessible in EP's dtor, so hold ownership in here + mutable std::unordered_map> inuse_contexts_; // reuse thread local context - mutable std::deque> context_pool_; + mutable std::deque> retired_context_pool_; mutable OrtMutex context_pool_mutex_; PerThreadContext& GetPerThreadContext() const;