remove the GetStream from cuda ep. (#1514)

* remove the GetStream from cuda ep.

* fix comments
This commit is contained in:
Ke Zhang 2019-07-29 15:01:29 -07:00 committed by GitHub
parent d6a30485be
commit cf5a4b5856
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 13 deletions

View file

@ -64,10 +64,6 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kCudaExecutionProvider}, device_id_(info.device_id) {
CUDA_CALL_THROW(cudaSetDevice(device_id_));
// create streams, default is nullptr
streams_[kCudaStreamDefault] = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking));
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));
DeviceAllocatorRegistrationInfo default_allocator_info(
{OrtMemTypeDefault, [](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
@ -93,9 +89,6 @@ CUDAExecutionProvider::~CUDAExecutionProvider() {
CUDA_CALL_THROW(cudaEventDestroy(e));
it = deferred_release_cpu_ptr_.erase(it);
}
CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyIn]));
CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyOut]));
ReleasePerThreadStuffs();
}

View file

@ -40,11 +40,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
return GetPerThreadContext().CudnnHandle();
}
cudaStream_t GetStream(int queue_id) const {
ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCudaStreams);
return streams_[queue_id];
}
template <typename T>
const T* GetConstOnes(size_t count) {
return GetPerThreadContext().template GetConstOnes<T>(count);
@ -69,7 +64,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
int GetDeviceId() const { return device_id_; }
private:
cudaStream_t streams_[kTotalCudaStreams];
int device_id_;
struct DeferredReleaseCPUPtrs {

View file

@ -12,6 +12,11 @@ GPUDataTransfer::GPUDataTransfer() {
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));
}
GPUDataTransfer::~GPUDataTransfer() {
CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyIn]));
CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyOut]));
}
bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED
|| dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;

View file

@ -18,6 +18,7 @@ enum CUDAStreamType : int {
class GPUDataTransfer : public IDataTransfer {
public:
GPUDataTransfer();
~GPUDataTransfer();
bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;