From cf5a4b5856594ad94f8b669c5be6929d452cd70a Mon Sep 17 00:00:00 2001 From: Ke Zhang Date: Mon, 29 Jul 2019 15:01:29 -0700 Subject: [PATCH] remove the GetStream from cuda ep. (#1514) * remove the GetStream from cuda ep. * fix comments --- onnxruntime/core/providers/cuda/cuda_execution_provider.cc | 7 ------- onnxruntime/core/providers/cuda/cuda_execution_provider.h | 6 ------ onnxruntime/core/providers/cuda/gpu_data_transfer.cc | 5 +++++ onnxruntime/core/providers/cuda/gpu_data_transfer.h | 1 + 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index bd7544396c..71b8777788 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -64,10 +64,6 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCudaExecutionProvider}, device_id_(info.device_id) { CUDA_CALL_THROW(cudaSetDevice(device_id_)); - // create streams, default is nullptr - streams_[kCudaStreamDefault] = nullptr; - CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking)); - CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking)); DeviceAllocatorRegistrationInfo default_allocator_info( {OrtMemTypeDefault, [](int id) { return std::make_unique(id); }, std::numeric_limits::max()}); @@ -93,9 +89,6 @@ CUDAExecutionProvider::~CUDAExecutionProvider() { CUDA_CALL_THROW(cudaEventDestroy(e)); it = deferred_release_cpu_ptr_.erase(it); } - CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyIn])); - CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyOut])); - ReleasePerThreadStuffs(); } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index ed3a509f85..bd6e25b18b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -40,11 +40,6 @@ class CUDAExecutionProvider : public IExecutionProvider { return GetPerThreadContext().CudnnHandle(); } - cudaStream_t GetStream(int queue_id) const { - ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCudaStreams); - return streams_[queue_id]; - } - template const T* GetConstOnes(size_t count) { return GetPerThreadContext().template GetConstOnes(count); @@ -69,7 +64,6 @@ class CUDAExecutionProvider : public IExecutionProvider { int GetDeviceId() const { return device_id_; } private: - cudaStream_t streams_[kTotalCudaStreams]; int device_id_; struct DeferredReleaseCPUPtrs { diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index ec930946aa..8fae7ae8b0 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -12,6 +12,11 @@ GPUDataTransfer::GPUDataTransfer() { CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking)); } +GPUDataTransfer::~GPUDataTransfer() { + CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyIn])); + CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyOut])); +} + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.h b/onnxruntime/core/providers/cuda/gpu_data_transfer.h index f0acae0f8d..0f3d4687eb 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.h +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.h @@ -18,6 +18,7 @@ enum CUDAStreamType : int { class GPUDataTransfer : public IDataTransfer { public: GPUDataTransfer(); + ~GPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;