Revert "[CUDA EP] remove per-thread allocator (#5415)" (#5647)

This reverts commit b4869926d3 because it broke our multiple GPU test pipeline.
This commit is contained in:
Changming Sun 2020-10-30 13:58:33 -07:00 committed by GitHub
parent 2c63196600
commit 3e71e8bd7e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 26 deletions

View file

@ -57,10 +57,23 @@ ONNX_OPERATOR_KERNEL_EX(
} // namespace cuda
CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id) {
CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy) {
CUDA_CALL_THROW(cudaSetDevice(device_id));
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId id) {
return onnxruntime::make_unique<CUDAAllocator>(id, CUDA);
},
device_id,
true,
{cuda_mem_limit,
static_cast<int>(arena_extend_strategy),
-1, -1});
// CUDA malloc/free is expensive so always use an arena
allocator_ = CreateAllocator(default_memory_info);
}
CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
@ -202,7 +215,7 @@ CUDAExecutionProvider::PerThreadContext& CUDAExecutionProvider::GetPerThreadCont
// get or create a context
if (context_state_.retired_context_pool.empty()) {
context = std::make_shared<PerThreadContext>(device_id_);
context = std::make_shared<PerThreadContext>(device_id_, cuda_mem_limit_, arena_extend_strategy_);
} else {
context = context_state_.retired_context_pool.back();
context_state_.retired_context_pool.pop_back();
@ -238,6 +251,17 @@ void CUDAExecutionProvider::ReleasePerThreadContext() const {
per_thread_context_cache->erase(cached_context_it);
}
AllocatorPtr CUDAExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
// Pinned memory allocator is shared between threads, but CUDA memory allocator is per-thread or it may cause result changes
// A hypothesis is that arena allocator is not aligned with CUDA output cache, and data from different kernel writes may
// cause cacheline to contain dirty data.
if (mem_type == OrtMemTypeDefault) {
return GetPerThreadContext().GetAllocator();
} else {
return IExecutionProvider::GetAllocator(id, mem_type);
}
}
Status CUDAExecutionProvider::Sync() const {
CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize());
return Status::OK();

View file

@ -34,6 +34,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
explicit CUDAExecutionProvider(const CUDAExecutionProviderInfo& info);
virtual ~CUDAExecutionProvider();
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
Status Sync() const override;
Status OnRunStart() override;
@ -55,24 +57,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
template <typename T>
const T* GetConstOnes(size_t count) {
if (std::is_same<T, float>::value) {
if (!constant_ones_float_) {
constant_ones_float_ = cuda::CreateConstantOnes<float>();
}
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(count));
} else if (std::is_same<T, double>::value) {
if (!constant_ones_double_) {
constant_ones_double_ = cuda::CreateConstantOnes<double>();
}
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(count));
} else if (std::is_same<T, half>::value) {
if (!constant_ones_half_) {
constant_ones_half_ = cuda::CreateConstantOnes<half>();
}
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(count));
} else {
return nullptr;
}
return GetPerThreadContext().template GetConstOnes<T>(count);
}
void AddDeferredReleaseCPUPtr(void* p);
@ -97,7 +82,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
int GetCudnnConvAlgo() const { return cudnn_conv_algo_; }
void UpdateProviderOptionsInfo();
private:
private:
OrtDevice::DeviceId device_id_;
cudaDeviceProp device_prop_;
size_t cuda_mem_limit_;
@ -113,13 +98,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
std::unordered_map<cudaEvent_t, DeferredReleaseCPUPtrs> deferred_release_cpu_ptr_;
OrtMutex deferred_release_cpu_ptr_mutex_;
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
class PerThreadContext final {
public:
PerThreadContext(OrtDevice::DeviceId device_id);
PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy);
~PerThreadContext();
cublasHandle_t CublasHandle() const {
@ -134,6 +115,32 @@ class CUDAExecutionProvider : public IExecutionProvider {
return current_deferred_release_event_;
}
template <typename T>
const T* GetConstOnes(size_t count) {
if (std::is_same<T, float>::value) {
if (!constant_ones_float_) {
constant_ones_float_ = cuda::CreateConstantOnes<float>();
}
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(count));
} else if (std::is_same<T, double>::value) {
if (!constant_ones_double_) {
constant_ones_double_ = cuda::CreateConstantOnes<double>();
}
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(count));
} else if (std::is_same<T, half>::value) {
if (!constant_ones_half_) {
constant_ones_half_ = cuda::CreateConstantOnes<half>();
}
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(count));
} else {
return nullptr;
}
}
AllocatorPtr GetAllocator() const {
return allocator_;
}
private:
cublasHandle_t cublas_handle_ = nullptr;
cudnnHandle_t cudnn_handle_ = nullptr;
@ -142,6 +149,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
// note that cudaEvent will be assigned at OnRunEnd() when PerThreadContext destory
// so the ownership is passed to deferred_release_cpu_ptr_
cudaEvent_t current_deferred_release_event_ = nullptr;
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
AllocatorPtr allocator_;
};
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::weak_ptr<PerThreadContext>>;