From d35f365ad578110e10afce56574e2d00bf59e727 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 17 Aug 2018 11:07:31 -0700 Subject: [PATCH] Remove all cuDNN specific inputs to RNN functions (#10581) Summary: This is still not the final PR, but it removes all blockers for actually using the RNN functions directly in the JIT. Next patch should be final, and will actually remove the symbolic_override code, and change it to proper symbolics for those ATen functions. Turns out the symbolic code can be also cleaned up a bit, and I'll do that too. zdevito ezyang colesbury (for minor DispatchStub.h) changes There was no way to handle those in the JIT for now, and they turned out to be completely unnecessary. It should make the Python and C++ module code much simpler too, since all the logic is now centralized in the native functions. The downside is that RNN modules no longer own their dropout buffers, which are shared per-device instead (with appropriate locking and synchronization). This might appear as a perf regression at first, but in reality it's highly unlikely that anyone will want to run cuDNN RNNs on the same GPU in parallel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10581 Reviewed By: colesbury Differential Revision: D9365541 Pulled By: apaszke fbshipit-source-id: 3ef8677ee5481bae60c74a9117a2508665b476b5 --- aten/src/ATen/cuda/CUDAEvent.cpp | 66 +++++ aten/src/ATen/cuda/CUDAEvent.h | 84 +++++++ aten/src/ATen/cuda/CUDAStream.cpp | 9 + aten/src/ATen/cuda/CUDAStream.h | 5 +- aten/src/ATen/native/DispatchStub.h | 35 ++- aten/src/ATen/native/RNN.cpp | 147 +++-------- aten/src/ATen/native/RNN.h | 23 ++ aten/src/ATen/native/cudnn/RNN.cpp | 272 ++++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 16 +- torch/nn/_functions/rnn.py | 18 +- torch/nn/modules/rnn.py | 16 -- torch/onnx/symbolic.py | 1 + 12 files changed, 531 insertions(+), 161 deletions(-) create mode 100644 aten/src/ATen/cuda/CUDAEvent.cpp create mode 100644 aten/src/ATen/cuda/CUDAEvent.h create mode 100644 aten/src/ATen/native/RNN.h diff --git a/aten/src/ATen/cuda/CUDAEvent.cpp b/aten/src/ATen/cuda/CUDAEvent.cpp new file mode 100644 index 00000000000..ab6c8421816 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAEvent.cpp @@ -0,0 +1,66 @@ +#include "ATen/cuda/CUDAEvent.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/CUDAStream.h" +#include "ATen/cuda/Exceptions.h" +#include "ATen/core/Error.h" + +#include +#include + +// Internal implementation is entirely hidden +struct CUDAEventInternals { + std::atomic refcount; + int64_t device; // Note: cudaGetDevice works with int32_t, not int64_t + cudaEvent_t event; +}; + +namespace at { +namespace cuda { + +namespace detail { + +/* +* Pointer-based event API +*/ +CUDAEventInternals* CUDAEvent_create(unsigned int flags) { + std::unique_ptr internals { new CUDAEventInternals() }; + internals->refcount = 1; + internals->device = current_device(); + AT_CUDA_CHECK(cudaEventCreateWithFlags(&internals->event, flags)); + return internals.release(); +} + +void CUDAEvent_retain(CUDAEventInternals* internals) { + internals->refcount++; +} + +void CUDAEvent_uncheckedFree(CUDAEventInternals* internals) { + if (--internals->refcount) { + cudaEventDestroy(internals->event); + } +} +cudaEvent_t CUDAEvent_event(CUDAEventInternals* internals) { + return internals->event; +} + +int64_t CUDAEvent_device(CUDAEventInternals* internals) { + return internals->device; +} + +void CUDAEvent_record(CUDAEventInternals* internals, const CUDAStream& stream) { + 
AT_CUDA_CHECK(cudaEventRecord(internals->event, stream)); +} + +} // namespace detail + +void CUDAEvent::record() const { + record(getCurrentCUDAStream()); +} + +void CUDAEvent::record(const CUDAStream& stream) const { + detail::CUDAEvent_record(internals_, stream); +} + + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h new file mode 100644 index 00000000000..7fd47e6a562 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include + +#include "cuda_runtime_api.h" + +#include +#include + +/* +* A CUDA event interface with no CUDA build dependency. +* +* Includes the CUDAEvent RAII class and a pointer-based event API. +*/ + +struct CUDAEventInternals; + +namespace at { +namespace cuda { + +struct CUDAStream; + +namespace detail { + +// Pointer-based API (for internal use) +// Note: ATen/Context is preferred to work with streams safely +AT_API CUDAEventInternals* CUDAEvent_create(unsigned int flags); +AT_API void CUDAEvent_retain(CUDAEventInternals* internals); +AT_API void CUDAEvent_uncheckedFree(CUDAEventInternals* internals); +AT_API cudaEvent_t CUDAEvent_event(CUDAEventInternals* internals); +AT_API int64_t CUDAEvent_device(CUDAEventInternals* internals); + +} // namespace detail + +struct CUDAEvent { + // Constants + static constexpr unsigned int DEFAULT_FLAGS = cudaEventDisableTiming; + + // Constructors + CUDAEvent(unsigned int flags = DEFAULT_FLAGS) + : internals_(detail::CUDAEvent_create(flags)) {} + + ~CUDAEvent() { detail::CUDAEvent_uncheckedFree(internals_); } + + CUDAEvent(const CUDAEvent& other) { + detail::CUDAEvent_retain(other.internals_); + internals_ = other.internals_; + } + + CUDAEvent(CUDAEvent&& other) { + std::swap(internals_, other.internals_); + } + + CUDAEvent& operator=(CUDAEvent other) noexcept { + std::swap(internals_, other.internals_); + return *this; + } + + explicit operator bool() const noexcept { + return internals_ != nullptr; + } + + operator cudaEvent_t() const { return detail::CUDAEvent_event(internals_); } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.internals_ < right.internals_; + } + + int64_t device() const { return detail::CUDAEvent_device(internals_); } + cudaEvent_t event() const { return detail::CUDAEvent_event(internals_); } + CUDAEventInternals* internals() const { return internals_; } + + void record() const; // Record on the current stream + void record(const CUDAStream& stream) const; + +private: + CUDAEventInternals* internals_; +}; + +} // namespace cuda +} // namespace at + diff --git a/aten/src/ATen/cuda/CUDAStream.cpp b/aten/src/ATen/cuda/CUDAStream.cpp index 2dab634bc71..12d571da7f4 100644 --- a/aten/src/ATen/cuda/CUDAStream.cpp +++ b/aten/src/ATen/cuda/CUDAStream.cpp @@ -1,5 +1,6 @@ #include "ATen/cuda/CUDAStream.h" #include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/CUDAEvent.h" #include "ATen/cuda/Exceptions.h" #include "ATen/core/Error.h" @@ -173,6 +174,10 @@ namespace detail { } } + void CUDAStream_synchronize_with(CUDAStreamInternals* ptr, const CUDAEvent& event) { + AT_CUDA_CHECK(cudaStreamWaitEvent(ptr->stream, event, 0)); + } + } // namespace detail /* @@ -194,5 +199,9 @@ namespace detail { std::swap(internals_, other.internals_); } + void CUDAStream::synchronize_with(const CUDAEvent& event) const { + detail::CUDAStream_synchronize_with(internals_, event); + } + } // namespace cuda } // namespace at diff --git 
a/aten/src/ATen/cuda/CUDAStream.h b/aten/src/ATen/cuda/CUDAStream.h index 545bccfdfbc..7a3e1e0595c 100644 --- a/aten/src/ATen/cuda/CUDAStream.h +++ b/aten/src/ATen/cuda/CUDAStream.h @@ -15,12 +15,13 @@ * The ATen Context interface should be preferred when working with streams. */ -// Forward-declares internals struct CUDAStreamInternals; namespace at { namespace cuda { +struct CUDAEvent; + namespace detail { // Pointer-based API (for internal use) @@ -102,6 +103,8 @@ struct CUDAStream { cudaStream_t stream() const { return detail::CUDAStream_stream(internals_); } CUDAStreamInternals* internals() const { return internals_; } + void synchronize_with(const CUDAEvent& event) const; + private: CUDAStreamInternals* internals_ = nullptr; }; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 4d4d3df1bd3..cb9dd9c05a3 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -50,15 +50,15 @@ struct AT_API DispatchStub { static_assert(std::is_pointer::value, "FnPtr should be a pointer type"); template - void operator()(Backend backend, ArgTypes... args) { + void operator()(Backend backend, ArgTypes&&... args) { if (backend == Backend::CPU) { if (!cpu_dispatch_ptr) { cpu_dispatch_ptr = choose_cpu_impl(); } - (*cpu_dispatch_ptr)(args...); + (*cpu_dispatch_ptr)(std::forward(args)...); } else if (backend == Backend::CUDA) { AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel"); - (*cuda_dispatch_ptr)(args...); + (*cuda_dispatch_ptr)(std::forward(args)...); } else { AT_ERROR("DispatchStub: unsupported backend", backend); } @@ -109,12 +109,33 @@ struct RegisterDispatch { #define DEFINE_DISPATCH(name) struct name name -#if defined(__CUDACC__) -#define REGISTER_DISPATCH(name, fn) \ +#define REGISTER_ARCH_DISPATCH(name, arch, fn) \ + template <> decltype(fn) DispatchStub::arch = fn; + +#ifdef HAVE_AVX_CPU_DEFINITION +#define REGISTER_AVX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX, fn) +#else +#define REGISTER_AVX_DISPATCH(name, fn) +#endif + +#ifdef HAVE_AVX2_CPU_DEFINITION +#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn) +#else +#define REGISTER_AVX2_DISPATCH(name, fn) +#endif + +#define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ + REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast(nullptr)) \ + REGISTER_AVX_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) + +#define REGISTER_CUDA_DISPATCH(name, fn) \ static RegisterDispatch name ## __register(name, fn); + +#if defined(__CUDACC__) +#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) #elif defined(CPU_CAPABILITY) -#define REGISTER_DISPATCH(name, fn) \ - template <> decltype(fn) DispatchStub::CPU_CAPABILITY = fn; +#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 4e7a23fd1ac..27f5452811e 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -1,3 +1,5 @@ +#include "ATen/native/RNN.h" + #include "ATen/ATen.h" #include "ATen/NativeFunctions.h" @@ -499,100 +501,6 @@ std::tuple _lstm_impl( return std::make_tuple(result.outputs, at::stack(hy, 0), at::stack(cy, 0)); } -//////////////////////////////////////////////////////////////////////////////// -// CUDNN BINDINGS -//////////////////////////////////////////////////////////////////////////////// - -// These must line up with the CUDNN mode codes: -// 
https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t -enum class CuDNNMode { rnn_relu = 0, rnn_tanh = 1, lstm = 2, gru = 3 }; - -std::tuple unpack_hidden(const Tensor& hidden) { - return std::make_tuple(hidden, at::Tensor{}); -} - -std::tuple unpack_hidden(const tpair_of& hidden) { - return hidden; -} - -template -hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) { - static_assert(std::is_same::value, "pack_hidden not implemented for this type"); - AT_ERROR("NOT IMPLEMENTED"); -} - -template<> -Tensor pack_hidden(const Tensor& hx, const Tensor& cx) { - AT_ASSERT(cx.numel() == 0); - return hx; -} - -template<> -tpair_of pack_hidden>(const Tensor& hx, const Tensor& cx) { - return std::make_tuple(hx, cx); -} - -const char * WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous " - "chunk of memory. This means they need to be compacted " - "at every call, possibly greatly increasing memory usage. " - "To compact weights again call flatten_parameters()."; - -template -LayerOutput _cudnn_impl( - const Tensor& input, const Tensor& _batch_sizes, - const hidden_type& hidden, - TensorList params, bool has_biases, - CuDNNMode cudnn_mode, const Tensor& weight_buf, const Tensor& dropout_state, - int64_t num_layers, double dropout_p, bool train, bool bidirectional) { - if (!weight_buf.defined()) { - AT_WARN(WEIGHT_FORMAT_WARN); - } - - Tensor hx, cx; - std::tie(hx, cx) = unpack_hidden(hidden); - - int64_t hidden_size = hx.size(2); - - AT_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D"); - IntList batch_sizes { _batch_sizes.data(), static_cast(_batch_sizes.size(0)) }; - // cudnn_output = std::tuple - auto cudnn_output = at::_cudnn_rnn( - input, params, has_biases ? 4 : 2, weight_buf, - hx, cx, static_cast(cudnn_mode), hidden_size, - num_layers, /*batch_first=*/false, dropout_p, train, bidirectional, - batch_sizes, dropout_state); - - return {std::get<0>(cudnn_output), - pack_hidden(std::get<1>(cudnn_output), std::get<2>(cudnn_output))}; -} - -template -LayerOutput _cudnn_impl( - const Tensor& input, - const hidden_type& hidden, - TensorList params, bool has_biases, - CuDNNMode cudnn_mode, const Tensor& weight_buf, const Tensor& dropout_state, - int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { - if (!weight_buf.defined()) { - AT_WARN(WEIGHT_FORMAT_WARN); - } - - Tensor hx, cx; - std::tie(hx, cx) = unpack_hidden(hidden); - - int64_t hidden_size = hx.size(2); - - // cudnn_output = std::tuple - auto cudnn_output = at::_cudnn_rnn( - input, params, has_biases ? 
4 : 2, weight_buf, - hx, cx, static_cast(cudnn_mode), hidden_size, - num_layers, batch_first, dropout_p, train, bidirectional, - /*batch_sizes=*/{}, dropout_state); - - return {std::get<0>(cudnn_output), - pack_hidden(std::get<1>(cudnn_output), std::get<2>(cudnn_output))}; -} - } // anonymous namespace //////////////////////////////////////////////////////////////////////////////// @@ -600,16 +508,20 @@ LayerOutput _cudnn_impl( //////////////////////////////////////////////////////////////////////////////// #define ONE_HIDDEN_RNN(NAME, CELL) \ +DEFINE_DISPATCH(NAME##_cudnn_stub); \ +DEFINE_DISPATCH(NAME##_packed_cudnn_stub); \ +REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub, rnn_fn); \ +REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub, rnn_packed_fn); \ + \ std::tuple NAME( \ const Tensor& _input, const Tensor& hx, \ TensorList _params, bool has_biases, \ - int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first, \ - const Tensor& cudnn_weight_buf, const Tensor& cudnn_dropout_state) { \ + int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { \ if (at::cudnn_is_acceptable(_input)) { \ - auto result = _cudnn_impl(_input, hx, _params, has_biases, \ - CuDNNMode::NAME, cudnn_weight_buf, cudnn_dropout_state, \ - num_layers, dropout_p, train, bidirectional, batch_first); \ - return std::make_tuple(result.outputs, result.final_hidden); \ + Tensor output, hy; \ + NAME##_cudnn_stub(_input.type().backend(), output, hy, _input, hx, _params, has_biases, \ + num_layers, dropout_p, train, bidirectional, batch_first); \ + return std::make_tuple(output, hy); \ } \ auto input = batch_first ? _input.transpose(0, 1) : _input; \ auto params = gather_params(_params, has_biases); \ @@ -624,12 +536,12 @@ std::tuple NAME( \ std::tuple NAME( \ const Tensor& data, const Tensor& batch_sizes, const Tensor& hx, \ TensorList _params, bool has_biases, \ - int64_t num_layers, double dropout_p, bool train, bool bidirectional, \ - const Tensor& cudnn_weight_buf, const Tensor& cudnn_dropout_state) { \ + int64_t num_layers, double dropout_p, bool train, bool bidirectional) { \ if (at::cudnn_is_acceptable(data)) { \ - auto result = _cudnn_impl(data, batch_sizes, hx, _params, has_biases, \ - CuDNNMode::NAME, cudnn_weight_buf, cudnn_dropout_state, num_layers, dropout_p, train, bidirectional); \ - return std::make_tuple(result.outputs, result.final_hidden); \ + Tensor output, hy; \ + NAME##_packed_cudnn_stub(data.type().backend(), output, hy, data, batch_sizes, hx, \ + _params, has_biases, num_layers, dropout_p, train, bidirectional); \ + return std::make_tuple(output, hy); \ } \ PackedSequence input { data, batch_sizes }; \ auto params = gather_params(_params, has_biases); \ @@ -643,16 +555,21 @@ ONE_HIDDEN_RNN(gru, GRUCell) ONE_HIDDEN_RNN(rnn_tanh, SimpleCell) ONE_HIDDEN_RNN(rnn_relu, SimpleCell) +DEFINE_DISPATCH(lstm_cudnn_stub); +DEFINE_DISPATCH(lstm_packed_cudnn_stub); +REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub, lstm_fn); +REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub, lstm_packed_fn); + std::tuple lstm( const Tensor& _input, TensorList hx, TensorList _params, bool has_biases, - int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first, - const Tensor& cudnn_weight_buf, const Tensor& cudnn_dropout_state) { + int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { AT_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(_input)) { - auto result = _cudnn_impl(_input, 
std::make_tuple(hx[0], hx[1]), _params, has_biases, - CuDNNMode::lstm, cudnn_weight_buf, cudnn_dropout_state, num_layers, dropout_p, train, bidirectional, batch_first); - return std::make_tuple(result.outputs, std::get<0>(result.final_hidden), std::get<1>(result.final_hidden)); + Tensor output, hy, cy; + lstm_cudnn_stub(_input.type().backend(), output, hy, cy, _input, hx, _params, has_biases, + num_layers, dropout_p, train, bidirectional, batch_first); + return std::make_tuple(output, hy, cy); } auto input = batch_first ? _input.transpose(0, 1) : _input; auto params = gather_params(_params, has_biases); @@ -667,13 +584,13 @@ std::tuple lstm( std::tuple lstm( const Tensor& data, const Tensor& batch_sizes, TensorList hx, TensorList _params, bool has_biases, - int64_t num_layers, double dropout_p, bool train, bool bidirectional, - const Tensor& cudnn_weight_buf, const Tensor& cudnn_dropout_state) { + int64_t num_layers, double dropout_p, bool train, bool bidirectional) { AT_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(data)) { - auto result = _cudnn_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]), _params, has_biases, - CuDNNMode::lstm, cudnn_weight_buf, cudnn_dropout_state, num_layers, dropout_p, train, bidirectional); - return std::make_tuple(result.outputs, std::get<0>(result.final_hidden), std::get<1>(result.final_hidden)); + Tensor output, hy, cy; + lstm_packed_cudnn_stub(data.type().backend(), output, hy, cy, data, batch_sizes, hx, + _params, has_biases, num_layers, dropout_p, train, bidirectional); + return std::make_tuple(output, hy, cy); } PackedSequence input { data, batch_sizes }; auto params = gather_params(_params, has_biases); diff --git a/aten/src/ATen/native/RNN.h b/aten/src/ATen/native/RNN.h new file mode 100644 index 00000000000..3fc89993404 --- /dev/null +++ b/aten/src/ATen/native/RNN.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool); +using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool); +using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool); +using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool); + +DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub); +DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub); + +}} // namespace at::native + diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 994a652dbaa..7a4e0ec7a3c 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #if !AT_CUDNN_ENABLED() @@ -451,7 +453,7 @@ namespace { // (same for the hh weights, and the ih and hh biases). 
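Stepping back from the weight layout for a moment: the dispatch plumbing added in RNN.h and RNN.cpp above follows the usual DispatchStub pattern. The backend-agnostic file defines a stub with no CPU kernel, and this cuDNN translation unit registers the CUDA implementation at static-initialization time, so RNN.cpp never needs a cuDNN include. A minimal, self-contained sketch of that pattern follows; the names and the simplified registration are illustrative, and the real stub additionally selects among DEFAULT/AVX/AVX2 CPU kernels.

```cpp
// Minimal sketch of the DispatchStub pattern used by the RNN entry points.
// Illustrative only; the real ATen macros also handle CPU capability selection.
#include <cstdio>
#include <utility>

enum class Backend { CPU, CUDA };

template <typename FnPtr>
struct DispatchStub {
  FnPtr cpu_ptr = nullptr;   // stays null when REGISTER_NO_CPU_DISPATCH is used
  FnPtr cuda_ptr = nullptr;  // filled in by the cuDNN translation unit

  template <typename... ArgTypes>
  void operator()(Backend backend, ArgTypes&&... args) {
    FnPtr fn = (backend == Backend::CUDA) ? cuda_ptr : cpu_ptr;
    if (!fn) { std::puts("DispatchStub: missing kernel for this backend"); return; }
    (*fn)(std::forward<ArgTypes>(args)...);
  }
};

// Stand-in for lstm_cudnn_stub; the real rnn_fn signature takes tensors,
// dropout probability, the training flag, and so on.
using rnn_fn = void (*)(int seq_len);
DispatchStub<rnn_fn> lstm_cudnn_stub;              // DEFINE_DISPATCH(lstm_cudnn_stub)

void lstm_cudnn(int seq_len) { std::printf("cuDNN LSTM over %d steps\n", seq_len); }

int main() {
  lstm_cudnn_stub.cuda_ptr = &lstm_cudnn;          // REGISTER_CUDA_DISPATCH(...)
  lstm_cudnn_stub(Backend::CUDA, 16);  // what at::lstm does when cudnn_is_acceptable
  lstm_cudnn_stub(Backend::CPU, 16);   // no CPU kernel registered, so the error path runs
}
```

The registrations at the bottom of this file plug the *_cudnn functions defined here into those stubs; the weight-layout handling continues below.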
// Since we're storing all the weights in a single tensor anyway, // might as well merge the CUDNN ones into a single tensor as well - int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data(); + int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data(); if (linear_id == 0 || linear_id == num_linear_layers / 2) { std::initializer_list size = { mat_numel * num_linear_layers / 2, 1}; @@ -477,6 +479,46 @@ namespace { return std::make_pair(params, global_layer_params_count); } + // This is a lightweight version of the method above used to quickly get the expected + // parameter offsets. + std::vector get_expected_data_ptrs( + const Tensor& weight_buf, cudnnHandle_t handle, const RNNDescriptorParams& rnn, + const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, cudnnDataType_t datatype) { + FilterDescriptor w_desc; + w_desc.set(weight_buf, 3); + + int64_t num_linear_layers = _num_linear_layers(rnn.mode); + int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers; + const auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams }; + std::vector data_ptrs; + data_ptrs.reserve(num_dir_layers * 2 * 2); + for (int64_t layer = 0; layer < num_dir_layers; layer++) { + for (auto cudnn_method : cudnn_methods) { + // This API returns a separate pointer for weight of every gate, + // but we represent them as a single tensor, so we're only interested + // in a very limited subset of possible values. + const std::array linear_offsets = { 0, num_linear_layers / 2 }; + for (int64_t linear_id : linear_offsets) { + FilterDescriptor lin_layer_mat_desc; + void* matrix_pointer; + AT_CUDNN_CHECK(cudnn_method( + handle, + rnn_desc.desc(), + layer, + x_desc.desc(), + w_desc.desc(), + weight_buf.data_ptr(), + linear_id, + lin_layer_mat_desc.mut_desc(), + &matrix_pointer + )); + data_ptrs.push_back(matrix_pointer); + } + } + } + return data_ptrs; + } + void _copyParams(MatrixRef params_from, MatrixRef params_to) { AT_ASSERTM(params_from.size(0) == params_to.size(0), "number of layers mismatch"); for (size_t i = 0; i < params_from.size(0); i++) { @@ -1007,6 +1049,234 @@ Tensor _cudnn_init_dropout_state(const Type& ty, double dropout, bool train, int return dropout_desc.state; } +//////////////////////////////////////////////////////////////////////////////// +// CUDA dispatch for the generic RNN ops (at::lstm, at::gru, ...) +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +// Helpers for working with different hidden types. 
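(The helpers just below let the LSTM and the single-hidden-state RNNs share one templated _cudnn_impl.) First, a note on the pointer bookkeeping above: get_expected_data_ptrs asks cuDNN where each weight matrix and bias should live inside the flat buffer, and the dispatch code further down compares those addresses against the parameters' data_ptr()s, falling back to the compacting path (and the WEIGHT_FORMAT_WARN message) on any mismatch. Here is a rough, self-contained sketch of that comparison, with plain arrays standing in for tensors and made-up offsets standing in for the cuDNN-reported pointers.

```cpp
// Sketch of the "are the parameters already views into one flat buffer?" check.
// Offsets are hypothetical; the real ones come from
// cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Param { const float* data; };  // stand-in for an at::Tensor's data_ptr()

// Returns true when every parameter sits at the offset cuDNN expects.
bool params_match_flat_buffer(const float* flat_buf,
                              const std::vector<std::size_t>& expected_offsets,
                              const std::vector<Param>& params) {
  if (expected_offsets.size() != params.size()) return false;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (flat_buf + expected_offsets[i] != params[i].data) return false;
  }
  return true;  // safe to hand the flat buffer straight to _cudnn_rnn
}

int main() {
  std::vector<float> flat(1024);
  // Two parameters carved out of the flat buffer at hypothetical offsets 0 and 512.
  std::vector<Param> params = { { flat.data() }, { flat.data() + 512 } };
  std::vector<std::size_t> expected_offsets = { 0, 512 };
  std::printf("flattened: %s\n",
              params_match_flat_buffer(flat.data(), expected_offsets, params) ? "yes" : "no");
}
```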
+std::tuple unpack_hidden(const Tensor& hidden) { + return std::make_tuple(hidden, at::Tensor{}); +} + +std::tuple unpack_hidden(const std::tuple& hidden) { + return hidden; +} + +template +hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) { + static_assert(std::is_same::value, "pack_hidden not implemented for this type"); + AT_ERROR("NOT IMPLEMENTED"); +} + +template<> +Tensor pack_hidden(const Tensor& hx, const Tensor& cx) { + AT_ASSERT(cx.numel() == 0); + return hx; +} + +template<> +std::tuple pack_hidden>(const Tensor& hx, const Tensor& cx) { + return std::make_tuple(hx, cx); +} + +struct DropoutState { + at::Tensor buffer; + cuda::CUDAEvent event; + std::mutex mutex; + + void lock() { + mutex.lock(); + cuda::getCurrentCUDAStream().synchronize_with(event); + } + + void unlock() { + event.record(); + mutex.unlock(); + } +}; + +DropoutState& get_dropout_state(const Type& tp, double dropout_p, bool train) { + // Each state is slightly over 2MB and initialized lazily, so it's fine to cache them. + static std::vector ten_dropout_state_cache { static_cast(cuda::getNumGPUs()) }; + static std::vector var_dropout_state_cache { static_cast(cuda::getNumGPUs()) }; + static std::mutex state_cache_mut; + + int device = cuda::current_device(); + std::unique_lock lock {state_cache_mut}; + auto& state = tp.is_variable() ? var_dropout_state_cache.at(device) + : ten_dropout_state_cache.at(device); + if (!state.buffer.defined() && dropout_p > 0 && train) { + int64_t seed = at::empty({}, at::kLong).random_().toCLong(); + state.buffer = at::_cudnn_init_dropout_state( + tp.toScalarType(at::kByte), dropout_p, train, seed); + // NB: This event will be already constructed by now, but CUDA actually binds the event + // to a device at creation time, and all events in the cache initially are assigned to + // the one that was active when this function is called for the first time. + state.event = cuda::CUDAEvent{}; + } + return state; +} + +Tensor try_get_weight_buf( + const Tensor& input, TensorList parameters, bool has_biases, + cudnnRNNMode_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional) { + // Prepare all relevant descriptors + auto handle = getCudnnHandle(); + auto datatype = getCudnnDataType(input); + + RNNDescriptorParams rnn; + rnn.set(mode, hidden_size, num_layers, bidirectional, datatype); + RNNDescriptor rnn_desc = rnn.descriptor(handle); + + TensorGeometry x_geom ({1, input.size(-1)}); + TensorDescriptor x_desc; + x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5); + + auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype); + + // Try to get parameter storage + auto & any_param = parameters.at(0); + auto param_storage = any_param.storage(); + auto weight_buf = any_param.type().tensor().set_(*param_storage); + if (weight_buf.size(0) < num_params) { + return {}; + } else if (weight_buf.size(0) > num_params) { + weight_buf = weight_buf.narrow(0, 0, num_params); + } + + // Get and check data pointers + auto expected_data_ptrs = get_expected_data_ptrs( + weight_buf, handle, rnn, rnn_desc, x_desc, datatype); + + int64_t num_parameters = parameters.size(); + int64_t num_ptrs = expected_data_ptrs.size(); + AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2))); + AT_ASSERT(num_ptrs % (has_biases ? 4 : 2) == 0); + for (int64_t param_i = 0, ptr_i = 0; + ptr_i < num_ptrs; + ptr_i += (has_biases ? 
2 : 4), param_i += 2) { + if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr()) return {}; + if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr()) return {}; + } + if (!parameters[num_parameters - 1].is_contiguous()) return {}; + return weight_buf; +} + +const char * WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous " + "chunk of memory. This means they need to be compacted " + "at every call, possibly greatly increasing memory usage. " + "To compact weights again call flatten_parameters()."; + +template +std::pair _cudnn_impl( + const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden, + TensorList params, bool has_biases, cudnnRNNMode_t mode, + int64_t num_layers, double dropout_p, bool train, bool bidirectional) { + Tensor hx, cx; + std::tie(hx, cx) = unpack_hidden(hidden); + int64_t hidden_size = hx.size(2); + + auto weight_buf = try_get_weight_buf( + input, params, has_biases, mode, hidden_size, num_layers, bidirectional); + if (!weight_buf.defined()) { + AT_WARN(WEIGHT_FORMAT_WARN); + } + + AT_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D"); + IntList batch_sizes { _batch_sizes.data(), static_cast(_batch_sizes.size(0)) }; + + auto & dropout_state = get_dropout_state(input.type(), dropout_p, train); + std::unique_lock lock { dropout_state }; + // cudnn_output = std::tuple + auto cudnn_output = at::_cudnn_rnn( + input, params, has_biases ? 4 : 2, weight_buf, + hx, cx, static_cast(mode), hidden_size, num_layers, /*batch_first=*/false, + dropout_p, train, bidirectional, batch_sizes, dropout_state.buffer); + + return {std::get<0>(cudnn_output), + pack_hidden(std::get<1>(cudnn_output), std::get<2>(cudnn_output))}; +} + +template +std::pair _cudnn_impl( + const Tensor& input, const hidden_type& hidden, + TensorList params, bool has_biases, cudnnRNNMode_t mode, + int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { + Tensor hx, cx; + std::tie(hx, cx) = unpack_hidden(hidden); + int64_t hidden_size = hx.size(2); + + auto weight_buf = try_get_weight_buf( + input, params, has_biases, mode, hidden_size, num_layers, bidirectional); + if (!weight_buf.defined()) { + AT_WARN(WEIGHT_FORMAT_WARN); + } + + auto & dropout_state = get_dropout_state(input.type(), dropout_p, train); + std::unique_lock lock { dropout_state }; + // cudnn_output = std::tuple + auto cudnn_output = at::_cudnn_rnn( + input, params, has_biases ? 
4 : 2, weight_buf, + hx, cx, static_cast(mode), hidden_size, num_layers, batch_first, dropout_p, + train, bidirectional, /*batch_sizes=*/{}, dropout_state.buffer); + + return {std::get<0>(cudnn_output), + pack_hidden(std::get<1>(cudnn_output), std::get<2>(cudnn_output))}; +} + +#define ONE_HIDDEN_RNN(NAME, MODE) \ +void NAME##_cudnn(Tensor& output, Tensor& hy, \ + const Tensor& input, const Tensor& hx, \ + TensorList params, bool has_biases, \ + int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { \ + std::tie(output, hy) = _cudnn_impl(input, hx, params, has_biases, \ + MODE, num_layers, dropout_p, train, bidirectional, batch_first); \ +} \ + \ +void NAME##_packed_cudnn(Tensor& output, Tensor& hy, \ + const Tensor& data, const Tensor& batch_sizes, const Tensor& hx, \ + TensorList params, bool has_biases, \ + int64_t num_layers, double dropout_p, bool train, bool bidirectional) { \ + std::tie(output, hy) = _cudnn_impl(data, batch_sizes, hx, params, \ + has_biases, MODE, num_layers, dropout_p, train, bidirectional); \ +} \ + \ +REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \ +REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn); + +ONE_HIDDEN_RNN(gru, CUDNN_GRU) +ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH) +ONE_HIDDEN_RNN(rnn_relu, CUDNN_RNN_RELU) + +void lstm_cudnn(Tensor& output, Tensor& hy, Tensor& cy, + const Tensor& input, TensorList hx, + TensorList params, bool has_biases, + int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { + auto result = _cudnn_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases, + CUDNN_LSTM, num_layers, dropout_p, train, bidirectional, batch_first); + output = result.first; + hy = std::get<0>(result.second); + cy = std::get<1>(result.second); +} + +void lstm_packed_cudnn(Tensor& output, Tensor& hy, Tensor& cy, + const Tensor& data, const Tensor& batch_sizes, TensorList hx, + TensorList params, bool has_biases, + int64_t num_layers, double dropout_p, bool train, bool bidirectional) { + auto result = _cudnn_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]), + params, has_biases, CUDNN_LSTM, num_layers, dropout_p, train, bidirectional); + output = result.first; + hy = std::get<0>(result.second); + cy = std::get<1>(result.second); +} + +REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn); +REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn); + +} // anonymous namepsace + }} // namespace at::native #endif // AT_CUDNN_ENABLED() diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 13a01d9161a..d7e0ef45d43 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2053,28 +2053,28 @@ variants: function # RNN cells and layers -- func: lstm(Tensor input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor, Tensor) +- func: lstm(Tensor input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) variants: function -- func: lstm(Tensor data, Tensor batch_sizes, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, Tensor? cudnn_weight_buf={}, Tensor? 
cudnn_dropout_state={}) -> (Tensor, Tensor, Tensor) +- func: lstm(Tensor data, Tensor batch_sizes, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) variants: function -- func: gru(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: gru(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) variants: function -- func: gru(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: gru(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) -> (Tensor, Tensor) variants: function -- func: rnn_tanh(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: rnn_tanh(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) variants: function -- func: rnn_tanh(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: rnn_tanh(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) -> (Tensor, Tensor) variants: function -- func: rnn_relu(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: rnn_relu(Tensor input, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) variants: function -- func: rnn_relu(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, Tensor? cudnn_weight_buf={}, Tensor? cudnn_dropout_state={}) -> (Tensor, Tensor) +- func: rnn_relu(Tensor data, Tensor batch_sizes, Tensor hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) -> (Tensor, Tensor) variants: function - func: lstm_cell(Tensor input, TensorList hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih={}, Tensor? 
b_hh={}) -> (Tensor, Tensor) diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py index ba7db5aa2d2..702cdb156b5 100644 --- a/torch/nn/_functions/rnn.py +++ b/torch/nn/_functions/rnn.py @@ -13,8 +13,7 @@ except ImportError: def _select_rnn_impl(mode, input_size, hidden_size, num_layers=1, batch_first=False, - dropout=0, train=True, bidirectional=False, variable_length=False, - dropout_state=None, flat_weight=None): + dropout=0, train=True, bidirectional=False, dropout_state=None): hidden_is_tensor = True if mode == 'RNN_RELU': impl = torch._C._VariableFunctions.rnn_relu @@ -31,18 +30,11 @@ def _select_rnn_impl(mode, input_size, hidden_size, num_layers=1, batch_first=Fa def forward(input, weight, hidden, batch_sizes): has_biases = len(weight[0]) == 4 weight = sum(weight, type(weight[0])()) - if cudnn.is_acceptable(input): - dropout_seed = int(torch.IntTensor(1).random_()) - with torch.cuda.device(input.get_device()): - dropout_ts = cudnn.rnn.init_dropout_state(dropout, train, dropout_seed, dropout_state) + + if batch_sizes is None: + result = impl(input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, batch_first) else: - dropout_ts = None - if not variable_length: - result = impl(input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, - batch_first, flat_weight, dropout_ts) - else: - result = impl(input, batch_sizes, hidden, weight, has_biases, num_layers, dropout, train, - bidirectional, flat_weight, dropout_ts) + result = impl(input, batch_sizes, hidden, weight, has_biases, num_layers, dropout, train, bidirectional) return result[0], (result[1] if hidden_is_tensor else result[1:]) return forward diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 68dc70a6d4b..20ff911ecd2 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -80,7 +80,6 @@ class RNNBase(Module): """ any_param = next(self.parameters()).data if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param): - self._data_ptrs = [] return # If any parameters alias, we fall back to the slower, copying code path. This is @@ -89,7 +88,6 @@ class RNNBase(Module): # Module.named_parameters(). 
unique_data_ptrs = set(p.data_ptr() for l in self.all_weights for p in l) if len(unique_data_ptrs) != sum(len(l) for l in self.all_weights): - self._data_ptrs = [] return with torch.cuda.device_of(any_param): @@ -108,9 +106,6 @@ class RNNBase(Module): self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers, self.batch_first, bool(self.bidirectional)) - self._param_buf_size = weight_buf.size(0) - self._data_ptrs = list(p.data.data_ptr() for p in self.parameters()) - def _apply(self, fn): ret = super(RNNBase, self)._apply(fn) self.flatten_parameters() @@ -171,14 +166,6 @@ class RNNBase(Module): if self.mode == 'LSTM': hx = (hx, hx) - has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs - if has_flat_weights: - first_data = next(self.parameters()).data - assert first_data.storage().size() == self._param_buf_size - flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size])) - else: - flat_weight = None - self.check_forward_args(input, hx, batch_sizes) func = _get_rnn_impl( self.mode, @@ -190,8 +177,6 @@ class RNNBase(Module): train=self.training, bidirectional=self.bidirectional, dropout_state=self.dropout_state, - variable_length=is_packed, - flat_weight=flat_weight ) output, hidden = func(input, self.all_weights, hx, batch_sizes) if is_packed: @@ -214,7 +199,6 @@ class RNNBase(Module): def __setstate__(self, d): super(RNNBase, self).__setstate__(d) - self.__dict__.setdefault('_data_ptrs', []) if 'all_weights' in d: self._all_weights = d['all_weights'] if isinstance(self._all_weights[0][0], str): diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index ca5702837bd..9695e6fc0cd 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1066,6 +1066,7 @@ def rnn_trace_override_symbolic(cell_type, func, sym, g, input, weights, hiddens inputs = list(itertools.chain.from_iterable( [[input], flattened_weights, hiddens, [batch_sizes] if batch_sizes else []])) + outputs = g.wrapPyFuncWithSymbolic( forward_flattened_wrapper, inputs,
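Finally, on the behavioral change called out in the summary: RNN modules no longer own a dropout-state buffer. Each device shares one, guarded by a mutex for CPU-side exclusion and by a CUDA event for GPU-side ordering (the event is recorded when the lock is released and waited on by the stream that next takes it). Below is a host-only sketch of that pattern, assuming illustrative names and with the CUDA calls reduced to comments so it builds without a GPU.

```cpp
// Host-only sketch of the shared, per-device DropoutState with event-based ordering.
#include <cstddef>
#include <mutex>
#include <vector>

struct Buffer { /* stand-in for the at::Tensor dropout-state buffer */ };

struct DropoutState {
  Buffer buffer;
  std::mutex mutex;
  // cuda::CUDAEvent event;  // recorded on unlock(), waited on in lock()

  // BasicLockable interface, so std::unique_lock<DropoutState> works.
  void lock() {
    mutex.lock();
    // getCurrentCUDAStream().synchronize_with(event);  // i.e. cudaStreamWaitEvent
  }
  void unlock() {
    // event.record();                                   // i.e. cudaEventRecord
    mutex.unlock();
  }
};

// One lazily created state per device, shared by every RNN on that device.
DropoutState& get_dropout_state(int device, int num_devices) {
  static std::vector<DropoutState> cache(static_cast<std::size_t>(num_devices));
  return cache.at(static_cast<std::size_t>(device));
}

int main() {
  DropoutState& state = get_dropout_state(/*device=*/0, /*num_devices=*/1);
  std::unique_lock<DropoutState> guard{state};  // held only around the _cudnn_rnn call
}
```

As the summary argues, serializing on this lock should be harmless in practice, since running several cuDNN RNNs concurrently on the same GPU is an unusual workload.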