diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 6476364a21..79f9bea719 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -153,7 +153,8 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { cudnn_direction_mode_, rnn_mode_, has_bias, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (get_B) { ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, @@ -296,7 +297,8 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { cudnn_direction_mode_, rnn_mode_, has_bias, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); // Prepare the weight data size_t w_data_size_in_bytes = 0; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 0fa01d3486..2bc937526f 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -5,8 +5,6 @@ #include "core/common/gsl.h" -#include - #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" @@ -40,10 +38,17 @@ class CudnnRNN { Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType, bool use_tf32) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); + cudnnMathType_t mathType = CUDNN_DEFAULT_MATH; + if (dataType == CUDNN_DATA_HALF) { + mathType = CUDNN_TENSOR_OP_MATH; + } else if (dataType == CUDNN_DATA_FLOAT && !use_tf32) { + mathType = CUDNN_FMA_MATH; // omit TF32 tensor cores + } + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC rnn_mode, @@ -52,7 +57,7 @@ class CudnnRNN { CUDNN_LINEAR_INPUT, dataType, dataType, - dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + mathType, gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), gsl::narrow_cast(proj_size), // projected size