mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix crash in releasing TLS from CUDA EP dtor (#2329)
thread_local/global/static destruction order depends on implementation details of compilers and OS. The bug happens when thread_local is already out of scope while static EP being destructed, thus causing access violation in EP's destructor when accessing thread_local. The fix is to maintain ownership inside EP with a mapping from tid to ThreadLocalContext, to avoid accessing thread_local in EP's destructor. This way, no matter what the destruction order is, no access violation would be triggered.
This commit is contained in:
parent
c0b8926863
commit
58e6aaa414
2 changed files with 29 additions and 17 deletions
|
|
@ -103,7 +103,6 @@ CUDAExecutionProvider::~CUDAExecutionProvider() {
|
|||
CUDA_CALL_THROW(cudaEventDestroy(e));
|
||||
it = deferred_release_cpu_ptr_.erase(it);
|
||||
}
|
||||
ReleasePerThreadStuffs();
|
||||
}
|
||||
|
||||
CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadContext() const {
|
||||
|
|
@ -114,28 +113,39 @@ CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadCont
|
|||
auto* p = per_thread_context_map_.get();
|
||||
if (p->count(this) == 0) {
|
||||
std::lock_guard<OrtMutex> lock(context_pool_mutex_);
|
||||
if (context_pool_.empty()) {
|
||||
p->insert(std::make_pair(this, std::make_shared<PerThreadContext>(device_id_)));
|
||||
unsigned int tid = logging::GetThreadId();
|
||||
auto inuse_iter = inuse_contexts_.find(tid);
|
||||
std::shared_ptr<PerThreadContext> ptc;
|
||||
if (inuse_iter == inuse_contexts_.end()) {
|
||||
if (retired_context_pool_.empty()) {
|
||||
ptc = std::make_shared<PerThreadContext>(device_id_);
|
||||
} else {
|
||||
ptc = retired_context_pool_.back();
|
||||
retired_context_pool_.pop_back();
|
||||
}
|
||||
} else {
|
||||
p->insert(std::make_pair(this, context_pool_.back()));
|
||||
context_pool_.pop_back();
|
||||
ptc = inuse_iter->second;
|
||||
}
|
||||
p->insert(std::make_pair(this, ptc));
|
||||
}
|
||||
return *(p->at(this));
|
||||
}
|
||||
|
||||
void CUDAExecutionProvider::ReleasePerThreadStuffs() const {
|
||||
if (per_thread_context_map_ != nullptr && !per_thread_context_map_->empty()) {
|
||||
auto iter_ctx = per_thread_context_map_->find(this);
|
||||
if (iter_ctx != per_thread_context_map_->end()) {
|
||||
std::lock_guard<OrtMutex> lock(context_pool_mutex_);
|
||||
context_pool_.push_back(iter_ctx->second);
|
||||
per_thread_context_map_->erase(iter_ctx);
|
||||
// Release TLS if empty to avoid memory leak report
|
||||
if (per_thread_context_map_->empty()) {
|
||||
per_thread_context_map_.reset(nullptr);
|
||||
}
|
||||
}
|
||||
ORT_ENFORCE(per_thread_context_map_ != nullptr);
|
||||
auto iter_ctx = per_thread_context_map_->find(this);
|
||||
ORT_ENFORCE(iter_ctx != per_thread_context_map_->end());
|
||||
|
||||
std::lock_guard<OrtMutex> lock(context_pool_mutex_);
|
||||
unsigned int tid = logging::GetThreadId();
|
||||
if (inuse_contexts_.count(tid)) {
|
||||
inuse_contexts_.erase(tid);
|
||||
}
|
||||
retired_context_pool_.push_back(iter_ctx->second);
|
||||
per_thread_context_map_->erase(iter_ctx);
|
||||
// Release TLS if empty to avoid memory leak report
|
||||
if (per_thread_context_map_->empty()) {
|
||||
per_thread_context_map_.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -138,9 +138,11 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
// thread local context during execution
|
||||
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::shared_ptr<PerThreadContext>>;
|
||||
static thread_local std::unique_ptr<PerThreadContextMap> per_thread_context_map_;
|
||||
// in some compilers, TLS may not be accessible in EP's dtor, so hold ownership in here
|
||||
mutable std::unordered_map<unsigned int, std::shared_ptr<PerThreadContext>> inuse_contexts_;
|
||||
|
||||
// reuse thread local context
|
||||
mutable std::deque<std::shared_ptr<PerThreadContext>> context_pool_;
|
||||
mutable std::deque<std::shared_ptr<PerThreadContext>> retired_context_pool_;
|
||||
mutable OrtMutex context_pool_mutex_;
|
||||
|
||||
PerThreadContext& GetPerThreadContext() const;
|
||||
|
|
|
|||
Loading…
Reference in a new issue