From fbf274f5a7c55f58ee1f7eb9b515f23f29bff443 Mon Sep 17 00:00:00 2001 From: Michael Carilli Date: Tue, 18 Aug 2020 13:36:02 -0700 Subject: [PATCH] Autocast support for cudnn RNNs (#42385) Summary: Should close https://github.com/pytorch/pytorch/issues/36428. The cudnn RNN API expects weights to occupy a flat buffer in memory with a particular layout. This PR implements a "speed of light" fix: [`_cudnn_rnn_cast_reflatten`](https://github.com/pytorch/pytorch/pull/42385/files#diff-9ef93b6a4fb5a06a37c562b83737ac6aR327) (the autocast wrapper assigned to `_cudnn_rnn`) copies weights to the right slices of a flat FP16 buffer with a single read/write per weight (as opposed to casting them to FP16 individually then reflattening the individual FP16 weights, which would require 2 read/writes per weight). It isn't pretty but IMO it doesn't make rnn bindings much more tortuous than they already are. The [test](https://github.com/pytorch/pytorch/pull/42385/files#diff-e68a7bc6ba14f212e5e7eb3727394b40R2683) tries a forward under autocast and a backward for the full cross product of RNN options and input/weight/hidden dtypes. As for all FP16list autocast tests, forward output and backward grads are checked against a control where inputs (including RNN module weights in this case) are precasted to FP16 on the python side. Not sure who to ask for review, tagging ezyang and ngimel because Ed wrote this file (almost 2 years ago) and Natalia did the most recent major [surgery](https://github.com/pytorch/pytorch/pull/12600). Side quests discovered: - Should we update [persistent RNN heuristics](https://github.com/pytorch/pytorch/blob/dbdd28207c5cf6c4a35ceb1de0811c4812e8882c/aten/src/ATen/native/cudnn/RNN.cpp#L584) to include compute capability 8.0? Could be another PR but seems easy enough to include. - Many (maybe all?!) the raw cudnn API calls in [RNN.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp) are deprecated in cudnn 8. I don't mind taking the AI to update them since my mental cache is full of rnn stuff, but that would be a substantial separate PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/42385 Reviewed By: zhangguanheng66 Differential Revision: D23077782 Pulled By: ezyang fbshipit-source-id: a2afb1bdab33ba0442879a703df13dc87f03ec2e --- BUILD.bazel | 1 + aten/src/ATen/autocast_mode.cpp | 174 +++--------------- aten/src/ATen/autocast_mode.h | 123 +++++++++++++ aten/src/ATen/cudnn/AutocastRNN.cpp | 117 ++++++++++++ aten/src/ATen/cudnn/Types.cpp | 16 +- aten/src/ATen/cudnn/Types.h | 1 + aten/src/ATen/native/RNN.cpp | 4 +- aten/src/ATen/native/RNN.h | 9 +- aten/src/ATen/native/cudnn/RNN.cpp | 202 +++++++++++++-------- aten/src/ATen/native/cudnn/RNNUtils.h | 28 +++ aten/src/ATen/native/miopen/RNN_miopen.cpp | 2 +- test/test_cuda.py | 83 ++++++++- 12 files changed, 531 insertions(+), 229 deletions(-) create mode 100644 aten/src/ATen/cudnn/AutocastRNN.cpp create mode 100644 aten/src/ATen/native/cudnn/RNNUtils.h diff --git a/BUILD.bazel b/BUILD.bazel index 3c6d5c24820..da50ea2d4b8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -340,6 +340,7 @@ filegroup( "aten/src/ATen/cuda/CublasHandlePool.cpp", "aten/src/ATen/cuda/PinnedMemoryAllocator.cpp", "aten/src/ATen/cuda/detail/CUDAHooks.cpp", + "aten/src/ATen/cudnn/AutocastRNN.cpp", "aten/src/ATen/cudnn/Descriptors.cpp", "aten/src/ATen/cudnn/Handle.cpp", "aten/src/ATen/cudnn/Types.cpp", diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 199613122ad..74c4e0b69cf 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -59,90 +59,10 @@ int decrement_nesting() { return --nesting; } -// Policies correspond to op categories that need code-divergent handling. -// Wrapper templates below are specialized based on a policy template parameter. -enum class CastPolicy : uint8_t { - fp16 = 0, // Cast all inputs to at::kHalf before running the op. - fp32, // Cast all inputs to at::kFloat before running the op. - fp32_set_opt_dtype, // Treats functions (like softmax) that - // 1. we'd like to run in fp32 and - // 2. have a c10::optional arg that controls the output type. - // fp32_set_opt_dtype wrappers' policy is: if the output type is already set, - // don't touch it, otherwise, set it to at::kFloat. - fp32_append_dtype, // Treats functions (like norm) that - // 1. we'd like to run in fp32 and - // 2. have some overloads that accept an output type and other overloads that don't. - // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype. - // The wrapper policy is: append at::kFloat to the args, and redispatch to the - // type-aware overload. - promote, // Run in the widest dtype among several args. -}; - -/******************************************************************** -Logic to extract the promote type from any Tensor or TensorList args. -********************************************************************/ - -// Overload to catch Tensor args. -// If nextArg is floating-point, compare its scalar_type with our -// current best guess for the promote type, and update if necessary. -inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) { - if (current == at::kDouble) { - AT_ERROR("promote type is double in at::autocast::prioritize"); - return current; - } - if (nextArg.is_cuda() && nextArg.is_floating_point()) { - auto next = nextArg.scalar_type(); - if (next == at::kDouble) { - return current; // ignores double tensors - } else if (current == at::kFloat || next == at::kFloat) { - return at::kFloat; // prioritizes float over half - } else if (current == at::kHalf && next == at::kHalf) { - return at::kHalf; - } else { - AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize"); - return current; - } - } else { - return current; - } -} - -// Overload to catch TensorList args (for e.g. cat, stack). -// Reuses the overload above to process each Tensor in the list. -inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) { - for (const auto& tensor : list) { - current = prioritize(current, tensor); - } - return current; -} - -// Template to catch non-Tensor args (no-op that returns current best guess) -template -inline at::ScalarType prioritize(at::ScalarType current, T nextArg) { - return current; -} - -// Overload for the tail case. -inline at::ScalarType promote_type(at::ScalarType current) { - return current; -} - -// Unpack args and determine if incoming float16 tensors need to be promoted to float32. -// Non-Tensor arguments are ignored. -template -inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) { - auto new_current = prioritize(current, arg0); - return promote_type(new_current, args...); -} - -/**************************************************** -Logic to apply cached casting to any Tensor argument. -****************************************************/ -inline bool is_eligible(const Tensor& arg) { - return (arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble)); -} - // Overload to catch Tensor args +// TODO (possible optimization): Move cast_cache to an inline function in a header +// (+ refactor the can_try_cache branch to call a small non-inline helper function. +// can_try_cache branch is the only part that's hard to inline in other files). Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) { if (is_eligible(arg) && (arg.scalar_type() != to_type)) { // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves). @@ -165,61 +85,24 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) { } } -// Overload to process optional -c10::optional cached_cast(at::ScalarType to_type, const c10::optional& arg) { - if (arg.has_value()) { - return cached_cast(to_type, *arg); - } else { - return c10::nullopt; - } -} - -// Overload to process TensorLists -std::vector cached_cast(at::ScalarType to_type, const TensorList& arg) { - std::vector vec; - vec.reserve(arg.size()); - for (const auto& t : arg) { - vec.push_back(cached_cast(to_type, t)); - } - return vec; -} - -// Template to catch non-Tensor args. -template -T cached_cast(at::ScalarType to_type, T arg) { - return arg; -} - -/******************************************************* -Logic to flip an output dtype flag. -Keep it simple for now by assuming only one such flag is -present in the argument list. If I ever need a function -with more than flag I'll figure out something else. -The policy is: -If the user has explicity specified a dtype, respect it. -Otherwise, set it to the autocast type. -********************************************************/ - -// Overload to catch dtype flags -c10::optional set_opt_dtype(at::ScalarType to_type, const c10::optional& dtype) { - return dtype.has_value() ? dtype : to_type; -} - -// Template to catch other args -template -inline T set_opt_dtype(at::ScalarType to_type, T arg) { - return arg; -} - -template -inline bool firstarg_is_eligible(const Tensor& arg, Args... args) { - return is_eligible(arg); -} - -template -inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) { - return (is_eligible(arg) ? to_type : arg.scalar_type()); -} +// Policies correspond to op categories that need code-divergent handling. +// Wrapper templates below are specialized based on a policy template parameter. +enum class CastPolicy : uint8_t { + fp16 = 0, // Cast all inputs to at::kHalf before running the op. + fp32, // Cast all inputs to at::kFloat before running the op. + fp32_set_opt_dtype, // Treats functions (like softmax) that + // 1. we'd like to run in fp32 and + // 2. have a c10::optional arg that controls the output type. + // fp32_set_opt_dtype wrappers' policy is: if the output type is already set, + // don't touch it, otherwise, set it to at::kFloat. + fp32_append_dtype, // Treats functions (like norm) that + // 1. we'd like to run in fp32 and + // 2. have some overloads that accept an output type and other overloads that don't. + // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype. + // The wrapper policy is: append at::kFloat to the args, and redispatch to the + // type-aware overload. + promote, // Run in the widest dtype among several args. +}; /******************************************************************************************************** Templates to provide wrapper functions @@ -239,7 +122,7 @@ template struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); return (*F)(cached_cast(at::kHalf, args)...); } }; @@ -248,7 +131,7 @@ struct WrapFunction_ struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); return (*F)(cached_cast(at::kFloat, args)...); } }; @@ -257,7 +140,7 @@ struct WrapFunction_ struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); if (firstarg_is_eligible(args...)) { return (*F)(set_opt_dtype(at::kFloat, args)...); } else { @@ -272,7 +155,7 @@ struct WrapFunction_ struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); at::ScalarType out_type = type_from_firstarg(at::kFloat, args...); return (*F)(args..., out_type); } @@ -282,7 +165,7 @@ struct WrapFunction_ struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); auto to_type = promote_type(at::kHalf, args...); return (*F)(cached_cast(to_type, args)...); } @@ -319,6 +202,7 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op "safe to autocast."); } + #ifndef USE_STATIC_DISPATCH namespace { /***************************************************************************************************************** @@ -422,7 +306,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional&, const c10::optional&, double, bool), fp32) // The macro doesn't like this one so I had to write it out manually. m.impl("native_layer_norm", - TORCH_FN((&WrapFunction (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), std::tuple (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call))); + TORCH_FN((&WrapFunction (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), std::tuple (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call))); KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional&, const c10::optional&, double, bool), fp32) KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32) KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32) @@ -490,7 +374,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote) m.impl("binary_cross_entropy", - TORCH_FN((&at::autocast::binary_cross_entropy_banned))); + TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } } diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 739b5324396..ca643f401bf 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -9,5 +9,128 @@ TORCH_API void clear_cache(); TORCH_API int increment_nesting(); TORCH_API int decrement_nesting(); +/******************************************************************** +Logic to extract the promote type from any Tensor or TensorList args. +********************************************************************/ + +// Overload to catch Tensor args. +// If nextArg is floating-point, compare its scalar_type with our +// current best guess for the promote type, and update if necessary. +inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) { + if (current == at::kDouble) { + AT_ERROR("promote type is double in at::autocast::prioritize"); + return current; + } + if (nextArg.is_cuda() && nextArg.is_floating_point()) { + auto next = nextArg.scalar_type(); + if (next == at::kDouble) { + return current; // ignores double tensors + } else if (current == at::kFloat || next == at::kFloat) { + return at::kFloat; // prioritizes float over half + } else if (current == at::kHalf && next == at::kHalf) { + return at::kHalf; + } else { + AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize"); + return current; + } + } else { + return current; + } +} + +// Overload to catch TensorList args (for e.g. cat, stack). +// Reuses the overload above to process each Tensor in the list. +inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) { + for (const auto& tensor : list) { + current = prioritize(current, tensor); + } + return current; +} + +// Template to catch non-Tensor args (no-op that returns current best guess) +template +inline at::ScalarType prioritize(at::ScalarType current, T nextArg) { + return current; +} + +// Overload for the tail case. +inline at::ScalarType promote_type(at::ScalarType current) { + return current; +} + +// Unpack args and determine if incoming float16 tensors need to be promoted to float32. +// Non-Tensor arguments are ignored. +template +inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) { + auto new_current = prioritize(current, arg0); + return promote_type(new_current, args...); +} + +/**************************************************** +Logic to apply cached casting to any Tensor argument. +****************************************************/ +inline bool is_eligible(const Tensor& arg) { + return (arg.defined() && arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble)); +} + +// Overload to catch Tensor args +TORCH_API Tensor cached_cast(at::ScalarType to_type, const Tensor& arg); + +// Overload to process optional +inline c10::optional cached_cast(at::ScalarType to_type, const c10::optional& arg) { + if (arg.has_value()) { + return cached_cast(to_type, *arg); + } else { + return c10::nullopt; + } +} + +// Overload to process TensorLists +inline std::vector cached_cast(at::ScalarType to_type, const TensorList& arg) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.push_back(cached_cast(to_type, t)); + } + return vec; +} + +// Template to catch non-Tensor args. +template +inline T cached_cast(at::ScalarType to_type, T arg) { + return arg; +} + +/******************************************************* +Logic to flip an output dtype flag. +Keep it simple for now by assuming only one such flag is +present in the argument list. If I ever need a function +with more than flag I'll figure out something else. +The policy is: +If the user has explicity specified a dtype, respect it. +Otherwise, set it to the autocast type. +********************************************************/ + +// Overload to catch dtype flags +c10::optional inline set_opt_dtype(at::ScalarType to_type, const c10::optional& dtype) { + return dtype.has_value() ? dtype : to_type; +} + +// Template to catch other args +template +inline T set_opt_dtype(at::ScalarType to_type, T arg) { + return arg; +} + +template +inline bool firstarg_is_eligible(const Tensor& arg, Args... args) { + return is_eligible(arg); +} + +template +inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) { + return (is_eligible(arg) ? to_type : arg.scalar_type()); +} + } // namespace autocast } // namespace at diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp new file mode 100644 index 00000000000..4a900f1309b --- /dev/null +++ b/aten/src/ATen/cudnn/AutocastRNN.cpp @@ -0,0 +1,117 @@ +#include +#include +#include + +// pulls in AT_CUDNN_ENABLED() as defined by cmake +#include + +#if AT_CUDNN_ENABLED() +#include +#endif + +namespace at { +namespace autocast { + +/******************************************************************************** +Autocast wrapper for CuDNN RNNs (the weight reflattening needs special attention) +********************************************************************************/ + +// To be registered for the "_cudnn_rnn(...)" schema. +// _cudnn_rnn is autograd-exposed (test_autocast_cudnn_rnn in test_cuda.py includes a test to confirm) +std::tuple +_cudnn_rnn_cast_reflatten(const Tensor & input, + TensorList weight, + int64_t weight_stride0, + const c10::optional& weight_buf_opt, + const Tensor& hx, + const c10::optional& cx, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + bool batch_first, + double dropout, + bool train, + bool bidirectional, + IntArrayRef batch_sizes, + const c10::optional& dropout_state) { +#if AT_CUDNN_ENABLED() + c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); + + for (const auto& t : weight) { + TORCH_CHECK(weight[0].scalar_type() == t.scalar_type(), "Weight scalar types do not match."); + } + // weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters(). + // If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases). + // If bias is not enabled, there are 2 (ih and hh weights). + // This organization holds for all rnn types (RNN, GRU, and LSTM). + TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4), + "weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ", + weight_stride0); + + Tensor weight_buf, redispatch_weight_buf; + std::vector redispatch_weight; + // There's an implicit contract here with native/cudnn/RNN.cpp:_cudnn_impl, which calls at:_cudnn_rnn. + // Code here assumes if _cudnn_impl passes weight_buf_opt containing a defined tensor, that tensor + // is valid flat storage of the weights in their incoming dtype. + if (weight_buf_opt.has_value()) { + weight_buf = *weight_buf_opt; + } + bool needs_cast_and_flatten = (weight_buf.defined() ? + // weight_buf is valid. Only change it if it's eligible and not already FP16. + is_eligible(weight_buf) && (weight_buf.scalar_type() != at::kHalf) : + // weight_buf is not valid. Only create it if other weights are eligible and not already FP16. + is_eligible(weight[0]) && (weight[0].scalar_type() != at::kHalf)); + if (needs_cast_and_flatten) { + // Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer, + // with the right locations and layouts expected by cudnn. + // This is (and should be) autograd-exposed. + std::tie(redispatch_weight_buf, redispatch_weight) = + at::native::cudnn_rnn::copy_weights_to_flat_buf_views( + weight, + weight_stride0, + input.size(-1), + mode, + hidden_size, + num_layers, + batch_first, + bidirectional, + /*flat_buf_datatype=*/at::native::getCudnnDataTypeFromScalarType(at::kHalf), // could just hardcode CUDNN_DATA_HALF + /*flat_buf_options=*/weight[0].options().dtype(at::kHalf), + /*set_orig_weights_to_flat_buf=*/false, + /*allow_type_change=*/true, + /*include_bias=*/weight_stride0 == 4); + } + + return at::_cudnn_rnn( + cached_cast(at::kHalf, input), + needs_cast_and_flatten ? TensorList(redispatch_weight) : weight, + weight_stride0, + needs_cast_and_flatten ? redispatch_weight_buf : weight_buf, + cached_cast(at::kHalf, hx), + cached_cast(at::kHalf, cx), + mode, + hidden_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state); +#else // AT_CUDNN_ENABLED() + AT_ERROR("autocast::_cudnn_rnn_cast_reflatten: ATen not compiled with cuDNN support"); + return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // never reached, placates the compiler +#endif // AT_CUDNN_ENABLED() +} + +#ifndef USE_STATIC_DISPATCH +namespace { +TORCH_LIBRARY_IMPL(aten, Autocast, m) { + m.impl("_cudnn_rnn", + TORCH_FN((&at::autocast::_cudnn_rnn_cast_reflatten))); +} +} // anonymous namespace +#endif + +} // namespace autocast +} // namespace at diff --git a/aten/src/ATen/cudnn/Types.cpp b/aten/src/ATen/cudnn/Types.cpp index 58e41640b5e..e450aec661f 100644 --- a/aten/src/ATen/cudnn/Types.cpp +++ b/aten/src/ATen/cudnn/Types.cpp @@ -4,19 +4,23 @@ namespace at { namespace native { -cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) { - if (tensor.scalar_type() == at::kFloat) { +cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) { + if (dtype == at::kFloat) { return CUDNN_DATA_FLOAT; - } else if (tensor.scalar_type() == at::kDouble) { + } else if (dtype == at::kDouble) { return CUDNN_DATA_DOUBLE; - } else if (tensor.scalar_type() == at::kHalf) { + } else if (dtype == at::kHalf) { return CUDNN_DATA_HALF; } - std::string msg("getCudnnDataType() not supported for "); - msg += toString(tensor.scalar_type()); + std::string msg("getCudnnDataTypeFromScalarType() not supported for "); + msg += toString(dtype); throw std::runtime_error(msg); } +cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) { + return getCudnnDataTypeFromScalarType(tensor.scalar_type()); +} + int64_t cudnn_version() { return CUDNN_VERSION; } diff --git a/aten/src/ATen/cudnn/Types.h b/aten/src/ATen/cudnn/Types.h index 70e5dd4ab7e..90c4d56ed8a 100644 --- a/aten/src/ATen/cudnn/Types.h +++ b/aten/src/ATen/cudnn/Types.h @@ -5,6 +5,7 @@ namespace at { namespace native { +TORCH_CUDA_API cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype); cudnnDataType_t getCudnnDataType(const at::Tensor& tensor); int64_t cudnn_version(); diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index f298b84626e..80d97a7e8a0 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -1182,7 +1182,7 @@ bool _use_cudnn_rnn_flatten_weight() { batch_first); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ - check_device(_input, _params, hx); \ + check_attributes(_input, _params, hx); \ auto input = batch_first ? _input.transpose(0, 1) : _input; \ auto params = gather_params(_params, has_biases); \ auto results = \ @@ -1410,7 +1410,7 @@ std::tuple lstm( num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } - check_device(_input, _params, hx); + check_attributes(_input, _params, hx); auto input = batch_first ? _input.transpose(0, 1) : _input; auto params = gather_params(_params, has_biases); auto results = _lstm_impl( diff --git a/aten/src/ATen/native/RNN.h b/aten/src/ATen/native/RNN.h index 2f0fa69d692..2bdb9becf4f 100644 --- a/aten/src/ATen/native/RNN.h +++ b/aten/src/ATen/native/RNN.h @@ -27,8 +27,9 @@ DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub); DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub); DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub); -inline void check_device(const Tensor& input, const TensorList& params, const TensorList& hiddens) { +inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) { auto input_device = input.device(); + auto input_dtype = input.scalar_type(); auto check_tensors = [&](const std::string& name, const Tensor& t) { if (!t.defined()) return; @@ -36,6 +37,12 @@ inline void check_device(const Tensor& input, const TensorList& params, const Te TORCH_CHECK(input_device == t_device, "Input and ", name, " tensors are not at the same device, found input tensor at ", input_device, " and ", name, " tensor at ", t_device); + if (check_dtype) { + auto t_dtype = t.scalar_type(); + TORCH_CHECK(input_dtype == t_dtype, + "Input and ", name, " tensors are not the same dtype, found input tensor with ", + input_dtype, " and ", name, " tensor with ", t_dtype); + } }; for (auto h : hiddens) check_tensors("hidden", h); diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 24c498dbd21..71667482dd7 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -59,10 +59,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see #else // AT_CUDNN_ENABLED() -#include -#include -#include -#include +#include namespace at { namespace native { @@ -406,7 +403,8 @@ namespace { const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, const FilterDescriptor& w_desc, - const Tensor& weight_buf + const Tensor& weight_buf, + bool include_bias=true ) { auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams }; std::vector params; @@ -462,13 +460,19 @@ namespace { // might as well merge the CUDNN ones into a single tensor as well int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data_ptr(); if (linear_id == 0 || linear_id == num_linear_layers / 2) { - std::initializer_list size = { - mat_numel * num_linear_layers / 2, 1}; - // Generate a new parameter tensor which is a view into the - // weight_buf. - Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size); - params.emplace_back(std::move(param)); - layer_params_count++; + // We could also exclude bias params by restricting cudnn_methods to just { cudnnGetRNNLinLayerMatrixParams } + // at the very top. However, to do so would throw off the cur_offset account, which is currently a strict + // and informative check that all params are laid out the way we think they are. If include_bias is false, + // I'd rather keep full cur_offset checks rather than save some CPU overhead by skipping the cudnn_method = + // cudnnGetRNNLinLayerBiasParams iteration. + if (include_bias || cudnn_method != cudnnGetRNNLinLayerBiasParams) { + // Generate a new parameter tensor which is a view into the weight_buf. + std::initializer_list size = { + mat_numel * num_linear_layers / 2, 1}; + Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size); + params.emplace_back(std::move(param)); + layer_params_count++; + } } else { AT_ASSERTM(cur_offset == offset, "cur_offset = ", cur_offset, "; offset = ", offset); } @@ -526,7 +530,8 @@ namespace { return data_ptrs; } - void _viewOrCopyParams(MatrixRef params_from, MatrixRef params_to, bool copy) { + void _viewOrCopyParams(MatrixRef params_from, MatrixRef params_to, + bool copy, bool allow_type_change=false) { AT_ASSERTM(params_from.size(0) == params_to.size(0), "number of layers mismatch"); for (size_t i = 0; i < params_from.size(0); i++) { auto layer_params_from = params_from[i]; @@ -538,7 +543,12 @@ namespace { a != layer_params_from.end() && b != layer_params_to.end(); ++a, ++b) { auto param_from = *a, param_to = *b; - AT_ASSERTM(param_from.type() == param_to.type(), "parameter types mismatch"); + // if copying, allow_type_change may be true or false. + // if viewing, allow_type_change must be false. + TORCH_INTERNAL_ASSERT(copy || !allow_type_change, + "if viewing, type change is not allowed."); + TORCH_INTERNAL_ASSERT(allow_type_change || (param_from.scalar_type() == param_to.scalar_type()), + "parameter types mismatch"); if (copy) { param_to.copy_(param_from.view_as(param_to)); } else { @@ -607,6 +617,80 @@ namespace { } // anonymous namespace +// Utilities exposed in RNNUtils.h +namespace cudnn_rnn { + + TORCH_CUDA_API std::tuple> copy_weights_to_flat_buf_views( + TensorList weight_arr, + int64_t weight_stride0, + int64_t input_size, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + bool batch_first, + bool bidirectional, + const cudnnDataType_t flat_buf_datatype, + const TensorOptions& flat_buf_options, + bool set_orig_weights_to_flat_buf, + bool allow_type_change/*=false*/, + bool include_bias/*=true*/) { + // flat_buf_datatype is accepted as a separate argument (rather than extracted from flat_buf_options) + // because to extract flat_buf_datatype from flat_buf_options, we'd need to say + // auto flat_buf_datatype = getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype())); + // typeMetaToScalarType is a surprisingly nontrivial function. We should avoid it if we can. + + TORCH_CHECK(weight_arr.size() > 0, + "copy_weights_to_flat_buf_views: cannot flatten empty weight list"); + + RNNDescriptorParams rnn; + rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(flat_buf_datatype), flat_buf_datatype); + + auto handle = getCudnnHandle(); + RNNDescriptor rnn_desc = rnn.descriptor(handle); + + TensorGeometry x_geom({1, input_size}); + TensorDescriptor x_desc; + // Why do we pad to 5 dims here (and elsewhere)? + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNForwardTraining + // expects descriptors padded to 3 dimensions. + x_desc.set(flat_buf_datatype, x_geom.sizes(), x_geom.strides(), 5); + + auto num_weights = get_num_weights(handle, rnn_desc, x_desc, flat_buf_datatype); + auto weight_buf = at::zeros(num_weights, flat_buf_options); + + FilterDescriptor w_desc; + w_desc.set(weight_buf, 3); + + // Slice off views into weight_buf + std::vector params_arr; + size_t params_stride0; + std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias); + + MatrixRef weight{weight_arr, static_cast(weight_stride0)}, + params{params_arr, params_stride0}; + + // Copy weights + _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change); + + if (set_orig_weights_to_flat_buf) { + // Update the storage + for (size_t i = 0; i < weight.size(0); i++) { + for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); + orig_param_it != weight[i].end() && new_param_it != params[i].end(); + orig_param_it++, new_param_it++) { + auto orig_param = *orig_param_it, new_param = *new_param_it; + orig_param.set_(new_param.view_as(orig_param)); + } + } + } + + return {weight_buf, params_arr}; + } + +} // namespace cudnn_rnn + +using namespace cudnn_rnn; + // NB: does inplace update into TensorList // It would be a relatively simple matter to refactor this into multiple // functions, only one of which does an inplace update, but we leave this @@ -618,53 +702,26 @@ Tensor _cudnn_rnn_flatten_weight( int64_t fn_num_layers, bool batch_first, bool fn_bidirectional ) { - - TORCH_CHECK(weight_arr.size() > 0, - "_cudnn_rnn_flatten_weight_: cannot flatten empty weight list"); - - auto any_param = weight_arr[0]; - auto datatype = getCudnnDataType(any_param); - - RNNDescriptorParams rnn; - rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); - - auto handle = getCudnnHandle(); - RNNDescriptor rnn_desc = rnn.descriptor(handle); - - TensorGeometry x_geom({1, input_size}); - TensorDescriptor x_desc; - x_desc.set(getCudnnDataType(any_param), x_geom.sizes(), x_geom.strides(), 5); - - auto num_weights = get_num_weights(handle, rnn_desc, x_desc, datatype); - auto weight_buf = at::zeros(num_weights, any_param.options()); - - FilterDescriptor w_desc; - w_desc.set(weight_buf, 3); - - // Slice off views into weight_buf - std::vector params_arr; - size_t params_stride0; - std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf); - - MatrixRef weight{weight_arr, static_cast(weight_stride0)}, - params{params_arr, params_stride0}; - - // Copy weights - _copyParams(weight, params); - - // Update the storage - for (size_t i = 0; i < weight.size(0); i++) { - for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); - orig_param_it != weight[i].end() && new_param_it != params[i].end(); - orig_param_it++, new_param_it++) { - auto orig_param = *orig_param_it, new_param = *new_param_it; - orig_param.set_(new_param.view_as(orig_param)); - } - } - - return weight_buf; + // returns flat weight_buf + return std::get<0>(copy_weights_to_flat_buf_views( + weight_arr, + weight_stride0, + input_size, + fn_mode, + fn_hidden_size, + fn_num_layers, + batch_first, + fn_bidirectional, + /*flat_buf_datatype=*/getCudnnDataType(weight_arr[0]), + /*flat_buf_options=*/weight_arr[0].options(), + /*set_orig_weights_to_flat_buf=*/true)); } +const char * WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous " + "chunk of memory. This means they need to be compacted " + "at every call, possibly greatly increasing memory usage. " + "To compact weights again call flatten_parameters()."; + // NB: when fn_batch_sizes is empty, that means no batch sizes was specified std::tuple _cudnn_rnn( const Tensor& input_r, @@ -676,9 +733,12 @@ std::tuple _cudnn_rnn( const Tensor& fn_dropout_state ) { - check_device(input_r, weight, {hx, cx}); + check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true); auto input = input_r; auto weight_buf = weight_buf_r; + if (!weight_buf.defined()) { + TORCH_WARN(WEIGHT_FORMAT_WARN); + } if (fn_dropout_state.defined()) { auto input_arg = TensorArg(input, "input", 1); auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15); @@ -1172,9 +1232,11 @@ DropoutState& get_dropout_state(double dropout_p, bool train, TensorOptions opti Tensor try_get_weight_buf( const Tensor& input, TensorList parameters, bool has_biases, cudnnRNNMode_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional) { + // Prepare all relevant descriptors auto handle = getCudnnHandle(); - auto datatype = getCudnnDataType(input); + auto & any_param = parameters.at(0); + auto datatype = getCudnnDataType(any_param); RNNDescriptorParams rnn; rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype); @@ -1182,12 +1244,14 @@ Tensor try_get_weight_buf( TensorGeometry x_geom ({1, input.size(-1)}); TensorDescriptor x_desc; + // datatype for x_desc comes from any_param, not input. + // try_get_weight_buf's job is to check "is the weight buffer correctly laid out + // for us to run it with input of the same datatype?" x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5); auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype); // Try to get parameter storage - auto & any_param = parameters.at(0); auto param_storage = any_param.storage(); auto weight_buf = at::empty({0}, any_param.options()).set_(param_storage); if (weight_buf.size(0) < num_params) { @@ -1214,11 +1278,6 @@ Tensor try_get_weight_buf( return weight_buf; } -const char * WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous " - "chunk of memory. This means they need to be compacted " - "at every call, possibly greatly increasing memory usage. " - "To compact weights again call flatten_parameters()."; - template std::pair _cudnn_impl( const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden, @@ -1228,11 +1287,11 @@ std::pair _cudnn_impl( std::tie(hx, cx) = unpack_hidden(hidden); int64_t hidden_size = hx.size(2); + // TODO: try_get_weight_buf returns a Tensor, but _cudnn_rnn below takes a c10::optional + // in weight_buf's slot. Do we want try_get_weight_buf to return a c10::optional + // instead of a defined or undefined Tensor? auto weight_buf = try_get_weight_buf( input, params, has_biases, mode, hidden_size, num_layers, bidirectional); - if (!weight_buf.defined()) { - TORCH_WARN(WEIGHT_FORMAT_WARN); - } TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D"); IntArrayRef batch_sizes { _batch_sizes.data_ptr(), static_cast(_batch_sizes.size(0)) }; @@ -1260,9 +1319,6 @@ std::pair _cudnn_impl( auto weight_buf = try_get_weight_buf( input, params, has_biases, mode, hidden_size, num_layers, bidirectional); - if (!weight_buf.defined()) { - TORCH_WARN(WEIGHT_FORMAT_WARN); - } auto & dropout_state = get_dropout_state(dropout_p, train, input.options()); std::unique_lock lock { dropout_state }; diff --git a/aten/src/ATen/native/cudnn/RNNUtils.h b/aten/src/ATen/native/cudnn/RNNUtils.h new file mode 100644 index 00000000000..89b58ebef1d --- /dev/null +++ b/aten/src/ATen/native/cudnn/RNNUtils.h @@ -0,0 +1,28 @@ +#include +#include +#include +#include + +// Declares utilities used by RNN.cpp and also needed by external consumers +namespace at { +namespace native { +namespace cudnn_rnn { + +TORCH_CUDA_API std::tuple> copy_weights_to_flat_buf_views( + TensorList weight_arr, + int64_t weight_stride0, + int64_t input_size, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + bool batch_first, + bool bidirectional, + const cudnnDataType_t flat_buf_datatype, + const TensorOptions& flat_buf_options, + bool set_orig_weights_to_flat_buf, + bool allow_type_change=false, + bool include_bias=true); + +} // namespace cudnn_rnn +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp index a570fada86b..1493cece321 100644 --- a/aten/src/ATen/native/miopen/RNN_miopen.cpp +++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp @@ -437,7 +437,7 @@ std::tuple miopen_rnn( IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state ) { - check_device(input_r, weight, {hx, cx}); + check_attributes(input_r, weight, {hx, cx}); auto input = input_r; RNNParams fn; diff --git a/test/test_cuda.py b/test/test_cuda.py index 595c65aa744..cecb8529c25 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3,7 +3,7 @@ import io import tempfile import unittest import sys -from itertools import repeat, chain +from itertools import repeat, chain, product import os import gc import threading @@ -2600,6 +2600,87 @@ t2.start() model() model_jit_script() + # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) + # so they get a dedicated test. + # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). + @skipIfRocm + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') + def test_autocast_rnn(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + # seq, batch, features, hidden size + clses = ("RNN", "GRU", "LSTM") + T, B, F, H = 3, 4, 5, 6 + dtypes = (torch.float16, torch.float32) + input_layouts = ("seq_first", "batch_first", "packed") + + for (cls, num_layers, bias, input_layout, bidirectional, try_nonpreflattened_weights, + input_dtype, hidden_dtype, weight_dtype) in \ + product(clses, (1, 2), (True, False), input_layouts, (True, False), (True, False), + dtypes, dtypes, dtypes): + if input_layout == "seq_first": + batch_first = False + x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) + elif input_layout == "batch_first": + batch_first = True + x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) + elif input_layout == "packed": + batch_first = False + x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) + x = torch.nn.utils.rnn.pack_padded_sequence(torch.randn((T, B, F), + device="cuda", dtype=input_dtype), + lengths=(3, 2, 1, 3), + enforce_sorted=False) + + rnn = getattr(torch.nn, cls)(F, H, num_layers=num_layers, bidirectional=bidirectional, + bias=bias, batch_first=batch_first).cuda().to(dtype=weight_dtype) + + if try_nonpreflattened_weights: + for p in rnn.parameters(): + with torch.no_grad(): + p.set_(p.clone()) + + h = torch.randn((num_layers * (2 if bidirectional else 1), B, H), + device="cuda", dtype=hidden_dtype) + if cls == "LSTM": + c = torch.randn((num_layers * (2 if bidirectional else 1), B, H), + device="cuda", dtype=hidden_dtype) + h = (h, c) + + with torch.cuda.amp.autocast(): + out, h_out = rnn(x, h) + out = out.data if input_layout == "packed" else out + self.assertEqual(out.dtype, torch.float16) + # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee + # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has + # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. + self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward") + out.sum().backward() + grads = [p.grad.clone() for p in rnn.parameters()] + + rnn.zero_grad() + + if cls == "LSTM": + out_control, h_out_control = rnn.to(dtype=torch.float16)(x.half(), (h[0].half(), h[1].half())) + else: + out_control, h_out_control = rnn.to(dtype=torch.float16)(x.half(), h.half()) + out_control = out_control.data if input_layout == "packed" else out_control + out_control.sum().backward() + grads_control = [p.grad.clone() for p in rnn.parameters()] + + # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, + # autocast and control results should be bitwise identical. + self.assertEqual(out, out_control) + + if cls == "LSTM": + self.assertTrue(h_out[0].dtype is torch.float16 and h_out[1].dtype is torch.float16) + self.assertEqual(h_out[0], h_out_control[0]) + self.assertEqual(h_out[1], h_out_control[1]) + else: + self.assertEqual(h_out.dtype, torch.float16) + self.assertEqual(h_out, h_out_control) + for grad, grad_control in zip(grads, grads_control): + self.assertEqual(grad.half(), grad_control) + @slowTest @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_max_large_axis(self):