mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
[CUDA EP] remove per-thread allocator (#5415)
Now that we are using legacy default stream, which is shared among all inference threads, there is no need to have per-thread allocator. In the past, the race could happen when two threads running concurrently on GPU: thread1: allocA->copyA->computeA->freeA thread2: allocB->copyB->computeB->freeB Note that freeA/B only means the buffer is ready to be allocated on CPU, while the corresponding operation on GPU is not finished yet. It is possible for thread1/2 use the same buffer, when the alloc/free pair are not interleaved (note that alloc/free is thread-safe) If the GPU commands run in separate per-thread default stream, there's a chance that copyA/computeA are interleaved with copyB/computeB, even when the order in CPU execution is not interleaved. This would cause incorrect results if computeB uses copyA's results. By using one legacy default stream, CPU execution order would match the GPU execution order, so if A and B use the same buffer from alloc, the correpsonding copy/compute won't be interleaved. If the copy/compute is indeed interleaved, then allocA and allocB would return different buffers, thus no racing either.
This commit is contained in:
parent
2e1fa3ccb7
commit
b4869926d3
2 changed files with 26 additions and 63 deletions
|
|
@ -57,23 +57,10 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
|
||||
} // namespace cuda
|
||||
|
||||
CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy) {
|
||||
CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id) {
|
||||
CUDA_CALL_THROW(cudaSetDevice(device_id));
|
||||
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
|
||||
CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
|
||||
|
||||
AllocatorCreationInfo default_memory_info(
|
||||
[](OrtDevice::DeviceId id) {
|
||||
return onnxruntime::make_unique<CUDAAllocator>(id, CUDA);
|
||||
},
|
||||
device_id,
|
||||
true,
|
||||
{cuda_mem_limit,
|
||||
static_cast<int>(arena_extend_strategy),
|
||||
-1, -1});
|
||||
|
||||
// CUDA malloc/free is expensive so always use an arena
|
||||
allocator_ = CreateAllocator(default_memory_info);
|
||||
}
|
||||
|
||||
CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
|
||||
|
|
@ -215,7 +202,7 @@ CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadCont
|
|||
|
||||
// get or create a context
|
||||
if (context_state_.retired_context_pool.empty()) {
|
||||
context = std::make_shared<PerThreadContext>(device_id_, cuda_mem_limit_, arena_extend_strategy_);
|
||||
context = std::make_shared<PerThreadContext>(device_id_);
|
||||
} else {
|
||||
context = context_state_.retired_context_pool.back();
|
||||
context_state_.retired_context_pool.pop_back();
|
||||
|
|
@ -251,17 +238,6 @@ void CUDAExecutionProvider::ReleasePerThreadContext() const {
|
|||
per_thread_context_cache->erase(cached_context_it);
|
||||
}
|
||||
|
||||
AllocatorPtr CUDAExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
|
||||
// Pinned memory allocator is shared between threads, but CUDA memory allocator is per-thread or it may cause result changes
|
||||
// A hypothesis is that arena allocator is not aligned with CUDA output cache, and data from different kernel writes may
|
||||
// cause cacheline to contain dirty data.
|
||||
if (mem_type == OrtMemTypeDefault) {
|
||||
return GetPerThreadContext().GetAllocator();
|
||||
} else {
|
||||
return IExecutionProvider::GetAllocator(id, mem_type);
|
||||
}
|
||||
}
|
||||
|
||||
Status CUDAExecutionProvider::Sync() const {
|
||||
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
explicit CUDAExecutionProvider(const CUDAExecutionProviderInfo& info);
|
||||
virtual ~CUDAExecutionProvider();
|
||||
|
||||
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
|
||||
|
||||
Status Sync() const override;
|
||||
|
||||
Status OnRunStart() override;
|
||||
|
|
@ -57,7 +55,24 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
|
||||
template <typename T>
|
||||
const T* GetConstOnes(size_t count) {
|
||||
return GetPerThreadContext().template GetConstOnes<T>(count);
|
||||
if (std::is_same<T, float>::value) {
|
||||
if (!constant_ones_float_) {
|
||||
constant_ones_float_ = cuda::CreateConstantOnes<float>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(count));
|
||||
} else if (std::is_same<T, double>::value) {
|
||||
if (!constant_ones_double_) {
|
||||
constant_ones_double_ = cuda::CreateConstantOnes<double>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(count));
|
||||
} else if (std::is_same<T, half>::value) {
|
||||
if (!constant_ones_half_) {
|
||||
constant_ones_half_ = cuda::CreateConstantOnes<half>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(count));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void AddDeferredReleaseCPUPtr(void* p);
|
||||
|
|
@ -82,7 +97,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
int GetCudnnConvAlgo() const { return cudnn_conv_algo_; }
|
||||
void UpdateProviderOptionsInfo();
|
||||
|
||||
private:
|
||||
private:
|
||||
OrtDevice::DeviceId device_id_;
|
||||
cudaDeviceProp device_prop_;
|
||||
size_t cuda_mem_limit_;
|
||||
|
|
@ -98,9 +113,13 @@ private:
|
|||
std::unordered_map<cudaEvent_t, DeferredReleaseCPUPtrs> deferred_release_cpu_ptr_;
|
||||
OrtMutex deferred_release_cpu_ptr_mutex_;
|
||||
|
||||
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
|
||||
|
||||
class PerThreadContext final {
|
||||
public:
|
||||
PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy);
|
||||
PerThreadContext(OrtDevice::DeviceId device_id);
|
||||
~PerThreadContext();
|
||||
|
||||
cublasHandle_t CublasHandle() const {
|
||||
|
|
@ -115,32 +134,6 @@ private:
|
|||
return current_deferred_release_event_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* GetConstOnes(size_t count) {
|
||||
if (std::is_same<T, float>::value) {
|
||||
if (!constant_ones_float_) {
|
||||
constant_ones_float_ = cuda::CreateConstantOnes<float>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(count));
|
||||
} else if (std::is_same<T, double>::value) {
|
||||
if (!constant_ones_double_) {
|
||||
constant_ones_double_ = cuda::CreateConstantOnes<double>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(count));
|
||||
} else if (std::is_same<T, half>::value) {
|
||||
if (!constant_ones_half_) {
|
||||
constant_ones_half_ = cuda::CreateConstantOnes<half>();
|
||||
}
|
||||
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(count));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
AllocatorPtr GetAllocator() const {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
private:
|
||||
cublasHandle_t cublas_handle_ = nullptr;
|
||||
cudnnHandle_t cudnn_handle_ = nullptr;
|
||||
|
|
@ -149,12 +142,6 @@ private:
|
|||
// note that cudaEvent will be assigned at OnRunEnd() when PerThreadContext destory
|
||||
// so the ownership is passed to deferred_release_cpu_ptr_
|
||||
cudaEvent_t current_deferred_release_event_ = nullptr;
|
||||
|
||||
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
|
||||
|
||||
AllocatorPtr allocator_;
|
||||
};
|
||||
|
||||
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::weak_ptr<PerThreadContext>>;
|
||||
|
|
|
|||
Loading…
Reference in a new issue