diff --git a/caffe2/core/common_cudnn.cc b/caffe2/core/common_cudnn.cc index 8b41b018add..18f741bdad9 100644 --- a/caffe2/core/common_cudnn.cc +++ b/caffe2/core/common_cudnn.cc @@ -15,13 +15,12 @@ */ #include "caffe2/core/common_cudnn.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/init.h" namespace caffe2 { -thread_local CuDNNHandles CuDNNWrapper::tls_cudnn_handles_; - CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() { // New it (never delete) to avoid calling the destructors on process // exit and racing against the CUDA shutdown sequence. diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h index 5069eb0aaf2..f93f333befe 100644 --- a/caffe2/core/common_cudnn.h +++ b/caffe2/core/common_cudnn.h @@ -23,9 +23,7 @@ #include #include "caffe2/core/common.h" -#include "caffe2/core/common_gpu.h" #include "caffe2/core/context.h" -#include "caffe2/core/context_gpu.h" #include "caffe2/core/logging.h" #include "caffe2/core/types.h" #include "caffe2/proto/caffe2.pb.h" @@ -316,185 +314,6 @@ class cudnnFilterDescWrapper { DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper); }; -class CuDNNWrapper; -/** - * CuDNNHandles wraps around cudnnHandle_t so they can be - * properly destructed when threads exit. - */ -class CuDNNHandles { - friend class CuDNNWrapper; - - private: - CuDNNHandles() { - for (int i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) { - cudnn_handle_[i] = nullptr; - } - } - - ~CuDNNHandles() noexcept { - for (int i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) { - if (cudnn_handle_[i]) { - CUDNN_CHECK(cudnnDestroy(cudnn_handle_[i])); - } - } - } - - cudnnHandle_t cudnn_handle_[CAFFE2_COMPILE_TIME_MAX_GPUS]; -}; - -/** - * CuDNNWorkspace is a wrapper around a raw cuda pointer that holds the cudnn - * scratch space. This struct is meant to be only used in CuDNNWrapper to - * provide a program-wide scratch space for CuDNN. The reason behind it is that - * cudnn function calls are usually very efficient, hence one probably does not - * want to run multiple cudnn calls at the same time. As a result, one should - * not need more than one cudnn workspace per device. - */ -struct CuDNNWorkspace { - ~CuDNNWorkspace() noexcept {} - - void* get(size_t nbytes) { - if (nbytes_ < nbytes) { - reset(); - auto data_and_deleter = CUDAContext::New(nbytes); - data_ = {data_and_deleter.first, data_and_deleter.second}; - nbytes_ = nbytes; - } - CAFFE_ENFORCE_GE(nbytes_, nbytes); - return data_.get(); - } - - void reset() { - data_ = nullptr; - nbytes_ = 0; - } - - private: - std::unique_ptr data_{nullptr, NoDelete}; - size_t nbytes_{0}; -}; - -// CuDNNState is the owner of the CuDNNWorkspace, and serializes all -// executions of operations that use the state onto it's own stream -// (so multiple Net workers can reuse the same workspace from -// different threads and CUDA streams). -class CuDNNState { - public: - explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) { - DeviceGuard g(gpu_id_); - CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_)); - CUDA_ENFORCE(cudaEventCreate(&before_)); - CUDA_ENFORCE(cudaEventCreate(&after_)); - CUDA_ENFORCE(cudaStreamCreate(&stream_)); - CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_)); - } - - ~CuDNNState() noexcept { - DeviceGuard g(gpu_id_); - CUDNN_CHECK(cudnnDestroy(cudnn_handle_)); - CUDA_CHECK(cudaStreamDestroy(stream_)); - CUDA_CHECK(cudaEventDestroy(after_)); - CUDA_CHECK(cudaEventDestroy(before_)); - } - - cudnnHandle_t& cudnn_handle() { - return cudnn_handle_; - } - - CuDNNWorkspace& workspace() { - return workspace_; - } - - template - void execute(cudaStream_t stream, F&& f) { - CUDA_ENFORCE(cudaEventRecord(before_, stream)); - CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0)); - f(this); - CUDA_ENFORCE(cudaEventRecord(after_, stream_)); - CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0)); - } - - private: - cudnnHandle_t cudnn_handle_{nullptr}; - cudaEvent_t before_{nullptr}; - cudaEvent_t after_{nullptr}; - cudaStream_t stream_{nullptr}; - CuDNNWorkspace workspace_; - size_t gpu_id_{0}; - DISABLE_COPY_AND_ASSIGN(CuDNNState); -}; - -/** - * CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces. - * - * The wrapper ensures that for each thread and each gpu, there is one - * identical cudnn handle, which is also associated with the thread-local - * per-device cuda stream. The wrapper also hosts the device-specific cudnn - * workspace (scratch space for some cudnn functions). - * - */ -class CuDNNWrapper { - public: - /** - * Creates a cudnn wrapper associated with a CUDAContext object. Note that - * the CUDAContext object should outlive the CuDNNWrapper. - */ - explicit CuDNNWrapper(CUDAContext* context) : context_(context) {} - - /** - * Returns the inline cudnn handle that executes on the current - * thread's cuda_stream. - */ - cudnnHandle_t& inline_cudnn_handle() { - int gpu_id = context_->cuda_gpu_id(); - auto& cudnn_handle_ = tls_cudnn_handles_.cudnn_handle_[gpu_id]; - if (!cudnn_handle_) { - context_->SwitchToDevice(); - CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_)); - } - CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, context_->cuda_stream())); - return cudnn_handle_; - } - - // Executes the closure F on the CuDNNState associated with state_idx - template - void with_cudnn_state(size_t state_idx, F&& f) { - CAFFE_ENFORCE( - state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx"); - auto& sync_state = cudnn_states()[context_->cuda_gpu_id()][state_idx]; - - DeviceGuard dg(context_->cuda_gpu_id()); - - // We need to serialize execution on the CuDNNState as we can't - // allow multiple threads to race through the cudaEventRecord - // calls (so a worker thread might wait on another worker thread's - // execution) - std::lock_guard g(sync_state.mutex); - if (!sync_state.state.get()) { - sync_state.state.reset(new CuDNNState(context_->cuda_gpu_id())); - } - CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f); - } - - protected: - // Pointer to an external cuda context that the cudnn wrapper will use. - CUDAContext* context_; - static thread_local CuDNNHandles tls_cudnn_handles_; - - static constexpr size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4; - - struct SyncedCuDNNState { - std::mutex mutex; - std::unique_ptr state; - }; - - using PerGPUCuDNNStates = std::array< - std::array, - CAFFE2_COMPILE_TIME_MAX_GPUS>; - static PerGPUCuDNNStates& cudnn_states(); - - DISABLE_COPY_AND_ASSIGN(CuDNNWrapper); -}; } // namespace caffe2 diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index 6d0aec48bbc..6a2d3f41b9e 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -20,12 +20,13 @@ #include #include +#include "caffe2/core/common_cudnn.h" #include "caffe2/core/common_gpu.h" #include "caffe2/core/context.h" +#include "caffe2/core/logging.h" #include "caffe2/core/tensor.h" #include "caffe2/core/types.h" #include "caffe2/proto/caffe2.pb.h" -#include "caffe2/core/logging.h" namespace caffe2 { @@ -52,16 +53,18 @@ CudaMemoryPoolType GetCudaMemoryPoolType(); */ class ThreadLocalCUDAObjects { friend class CUDAContext; + private: ThreadLocalCUDAObjects() { for (int i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) { cuda_streams_[i] = vector(); cublas_handles_[i] = vector(); + cudnn_handles_[i] = vector(); } } cudaStream_t GetStream(int gpu, int stream_id) { - vector &gpu_streams = cuda_streams_[gpu]; + vector& gpu_streams = cuda_streams_[gpu]; if (gpu_streams.size() <= stream_id) { gpu_streams.resize(stream_id + 1, nullptr); } @@ -75,7 +78,7 @@ class ThreadLocalCUDAObjects { cublasHandle_t GetHandle(int gpu, int stream_id) { DeviceGuard guard(gpu); - vector &gpu_handles = cublas_handles_[gpu]; + vector& gpu_handles = cublas_handles_[gpu]; if (gpu_handles.size() <= stream_id) { gpu_handles.resize(stream_id + 1, nullptr); } @@ -92,6 +95,20 @@ class ThreadLocalCUDAObjects { return gpu_handles[stream_id]; } + cudnnHandle_t GetCudnnHandle(int gpu, int stream_id) { + DeviceGuard guard(gpu); + vector& gpu_handles = cudnn_handles_[gpu]; + if (gpu_handles.size() <= stream_id) { + gpu_handles.resize(stream_id + 1, nullptr); + } + if (!gpu_handles[stream_id]) { + CUDNN_ENFORCE(cudnnCreate(&gpu_handles[stream_id])); + CUDNN_ENFORCE( + cudnnSetStream(gpu_handles[stream_id], GetStream(gpu, stream_id))); + } + return gpu_handles[stream_id]; + } + ~ThreadLocalCUDAObjects() noexcept { for (int i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) { for (auto& handle : cublas_handles_[i]) { @@ -104,10 +121,16 @@ class ThreadLocalCUDAObjects { CUDA_CHECK(cudaStreamDestroy(stream)); } } + for (auto& handle : cudnn_handles_[i]) { + if (handle) { + CUDNN_CHECK(cudnnDestroy(handle)); + } + } } } vector cuda_streams_[CAFFE2_COMPILE_TIME_MAX_GPUS]; vector cublas_handles_[CAFFE2_COMPILE_TIME_MAX_GPUS]; + vector cudnn_handles_[CAFFE2_COMPILE_TIME_MAX_GPUS]; }; class CUDAContext final { @@ -166,6 +189,10 @@ class CUDAContext final { return cuda_objects_.GetHandle(gpu_id_, stream_id_); } + cudnnHandle_t cudnn_handle() { + return cuda_objects_.GetCudnnHandle(gpu_id_, stream_id_); + } + curandGenerator_t& curand_generator() { if (!curand_generator_) { DeviceGuard guard(gpu_id_); diff --git a/caffe2/core/cudnn_wrappers.h b/caffe2/core/cudnn_wrappers.h new file mode 100644 index 00000000000..c2910e2e658 --- /dev/null +++ b/caffe2/core/cudnn_wrappers.h @@ -0,0 +1,161 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#ifndef CAFFE2_CORE_CUDNN_WRAPPERS_H_ +#define CAFFE2_CORE_CUDNN_WRAPPERS_H_ + +#include "caffe2/core/common_cudnn.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +class CuDNNWrapper; + +/** + * CuDNNWorkspace is a wrapper around a raw cuda pointer that holds the cudnn + * scratch space. This struct is meant to be only used in CuDNNWrapper to + * provide a program-wide scratch space for CuDNN. The reason behind it is that + * cudnn function calls are usually very efficient, hence one probably does not + * want to run multiple cudnn calls at the same time. As a result, one should + * not need more than one cudnn workspace per device. + */ +struct CuDNNWorkspace { + ~CuDNNWorkspace() noexcept {} + + void* get(size_t nbytes) { + if (nbytes_ < nbytes) { + reset(); + auto data_and_deleter = CUDAContext::New(nbytes); + data_ = {data_and_deleter.first, data_and_deleter.second}; + nbytes_ = nbytes; + } + CAFFE_ENFORCE_GE(nbytes_, nbytes); + return data_.get(); + } + + void reset() { + data_ = nullptr; + nbytes_ = 0; + } + + private: + std::unique_ptr data_{nullptr, NoDelete}; + size_t nbytes_{0}; +}; + +// CuDNNState is the owner of the CuDNNWorkspace, and serializes all +// executions of operations that use the state onto it's own stream +// (so multiple Net workers can reuse the same workspace from +// different threads and CUDA streams). +class CuDNNState { + public: + explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) { + DeviceGuard g(gpu_id_); + CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_)); + CUDA_ENFORCE(cudaEventCreate(&before_)); + CUDA_ENFORCE(cudaEventCreate(&after_)); + CUDA_ENFORCE(cudaStreamCreate(&stream_)); + CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_)); + } + + ~CuDNNState() noexcept { + DeviceGuard g(gpu_id_); + CUDNN_CHECK(cudnnDestroy(cudnn_handle_)); + CUDA_CHECK(cudaStreamDestroy(stream_)); + CUDA_CHECK(cudaEventDestroy(after_)); + CUDA_CHECK(cudaEventDestroy(before_)); + } + + cudnnHandle_t& cudnn_handle() { + return cudnn_handle_; + } + + CuDNNWorkspace& workspace() { + return workspace_; + } + + template + void execute(cudaStream_t stream, F&& f) { + CUDA_ENFORCE(cudaEventRecord(before_, stream)); + CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0)); + f(this); + CUDA_ENFORCE(cudaEventRecord(after_, stream_)); + CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0)); + } + + private: + cudnnHandle_t cudnn_handle_{nullptr}; + cudaEvent_t before_{nullptr}; + cudaEvent_t after_{nullptr}; + cudaStream_t stream_{nullptr}; + CuDNNWorkspace workspace_; + size_t gpu_id_{0}; + DISABLE_COPY_AND_ASSIGN(CuDNNState); +}; + +/** + * CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces. + * + * The wrapper ensures that for each thread and each gpu, there is one + * identical cudnn handle, which is also associated with the thread-local + * per-device cuda stream. The wrapper also hosts the device-specific cudnn + * workspace (scratch space for some cudnn functions). + * + */ +class CuDNNWrapper { + public: + /** + * Creates a cudnn wrapper associated with a CUDAContext object. Note that + * the CUDAContext object should outlive the CuDNNWrapper. + */ + explicit CuDNNWrapper(CUDAContext* context) : context_(context) {} + + /** + * Returns the inline cudnn handle that executes on the current + * thread's cuda_stream. + */ + cudnnHandle_t inline_cudnn_handle() { + return context_->cudnn_handle(); + } + + // Executes the closure F on the CuDNNState associated with state_idx + template + void with_cudnn_state(size_t state_idx, F&& f) { + CAFFE_ENFORCE( + state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx"); + auto& sync_state = cudnn_states()[context_->cuda_gpu_id()][state_idx]; + + DeviceGuard dg(context_->cuda_gpu_id()); + + // We need to serialize execution on the CuDNNState as we can't + // allow multiple threads to race through the cudaEventRecord + // calls (so a worker thread might wait on another worker thread's + // execution) + std::lock_guard g(sync_state.mutex); + if (!sync_state.state.get()) { + sync_state.state.reset(new CuDNNState(context_->cuda_gpu_id())); + } + CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f); + } + + protected: + // Pointer to an external cuda context that the cudnn wrapper will use. + CUDAContext* context_; + + static constexpr size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4; + + struct SyncedCuDNNState { + std::mutex mutex; + std::unique_ptr state; + }; + + using PerGPUCuDNNStates = std::array< + std::array, + CAFFE2_COMPILE_TIME_MAX_GPUS>; + static PerGPUCuDNNStates& cudnn_states(); + + DISABLE_COPY_AND_ASSIGN(CuDNNWrapper); +}; + +}; // namespace caffe2 + +#endif diff --git a/caffe2/operators/conv_op_cudnn.cc b/caffe2/operators/conv_op_cudnn.cc index e2e6ee7bf92..3cb6d61005a 100644 --- a/caffe2/operators/conv_op_cudnn.cc +++ b/caffe2/operators/conv_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/operators/conv_op.h" #include "caffe2/operators/conv_op_cache_cudnn.h" #include "caffe2/operators/conv_pool_op_base.h" diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc index 0a9f5c7c4a5..e25dd4d83c2 100644 --- a/caffe2/operators/conv_transpose_op_cudnn.cc +++ b/caffe2/operators/conv_transpose_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/operators/conv_op_cache_cudnn.h" #include "caffe2/operators/conv_transpose_op.h" diff --git a/caffe2/operators/dropout_op_cudnn.cc b/caffe2/operators/dropout_op_cudnn.cc index 3d2c950a34b..5182b0f0240 100644 --- a/caffe2/operators/dropout_op_cudnn.cc +++ b/caffe2/operators/dropout_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" diff --git a/caffe2/operators/local_response_normalization_op_cudnn.cc b/caffe2/operators/local_response_normalization_op_cudnn.cc index f3bb33e322d..fbfc457e50c 100644 --- a/caffe2/operators/local_response_normalization_op_cudnn.cc +++ b/caffe2/operators/local_response_normalization_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" diff --git a/caffe2/operators/pool_op_cudnn.cu b/caffe2/operators/pool_op_cudnn.cu index 8fb8d7d7505..5c18c4a13e3 100644 --- a/caffe2/operators/pool_op_cudnn.cu +++ b/caffe2/operators/pool_op_cudnn.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/operators/conv_pool_op_base.h" #include diff --git a/caffe2/operators/recurrent_op_cudnn.h b/caffe2/operators/recurrent_op_cudnn.h index a1c681ecaf9..288554e67d7 100644 --- a/caffe2/operators/recurrent_op_cudnn.h +++ b/caffe2/operators/recurrent_op_cudnn.h @@ -17,9 +17,9 @@ #ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_ #define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" diff --git a/caffe2/operators/relu_op_cudnn.cc b/caffe2/operators/relu_op_cudnn.cc index 64a0a3ab5b4..a0eb6a0fe39 100644 --- a/caffe2/operators/relu_op_cudnn.cc +++ b/caffe2/operators/relu_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" diff --git a/caffe2/operators/softmax_op_cudnn.cc b/caffe2/operators/softmax_op_cudnn.cc index 6031844955d..0d5c77df256 100644 --- a/caffe2/operators/softmax_op_cudnn.cc +++ b/caffe2/operators/softmax_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/types.h" #include "caffe2/operators/softmax_op.h" diff --git a/caffe2/operators/spatial_batch_norm_op_cudnn.cc b/caffe2/operators/spatial_batch_norm_op_cudnn.cc index 86688da873f..af43109891e 100644 --- a/caffe2/operators/spatial_batch_norm_op_cudnn.cc +++ b/caffe2/operators/spatial_batch_norm_op_cudnn.cc @@ -16,8 +16,8 @@ #include -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/operators/spatial_batch_norm_op.h" #include "caffe2/utils/math.h" diff --git a/caffe2/operators/transpose_op_cudnn.cc b/caffe2/operators/transpose_op_cudnn.cc index 01376ded9a9..98f15305b77 100644 --- a/caffe2/operators/transpose_op_cudnn.cc +++ b/caffe2/operators/transpose_op_cudnn.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "caffe2/core/common_cudnn.h" #include "caffe2/core/context_gpu.h" +#include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/types.h" #include "caffe2/operators/transpose_op.h" #include "caffe2/operators/transpose_op_gpu.h"