[CUDA EP] RNN check if tf32 is allowed (#20338)

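Respect the use_tf32 flag: when it is disabled, float32 RNN descriptors are configured with CUDNN_FMA_MATH instead of CUDNN_DEFAULT_MATH, so cuDNN does not use TF32 tensor cores. Half-precision RNNs still use CUDNN_TENSOR_OP_MATH.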
This commit is contained in:
Maximilian Müller 2024-04-23 09:19:09 +02:00 committed by GitHub
parent 7ebc653f04
commit 5eae33fc6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 13 additions and 6 deletions

View file

@ -153,7 +153,8 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
cudnn_direction_mode_,
rnn_mode_,
has_bias,
CudnnTensor::GetDataType<CudaT>()));
CudnnTensor::GetDataType<CudaT>(),
UseTF32()));
if (get_B) {
ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B,
w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
@ -296,7 +297,8 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
cudnn_direction_mode_,
rnn_mode_,
has_bias,
CudnnTensor::GetDataType<CudaT>()));
CudnnTensor::GetDataType<CudaT>(),
UseTF32()));
// Prepare the weight data
size_t w_data_size_in_bytes = 0;

View file

@ -5,8 +5,6 @@
#include "core/common/gsl.h"
#include <cudnn.h>
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
@ -40,10 +38,17 @@ class CudnnRNN {
Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers,
cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) {
cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType, bool use_tf32) {
if (!cudnn_rnn_desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));
cudnnMathType_t mathType = CUDNN_DEFAULT_MATH;
if (dataType == CUDNN_DATA_HALF) {
mathType = CUDNN_TENSOR_OP_MATH;
} else if (dataType == CUDNN_DATA_FLOAT && !use_tf32) {
mathType = CUDNN_FMA_MATH; // omit TF32 tensor cores
}
CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_,
CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
rnn_mode,
@ -52,7 +57,7 @@ class CudnnRNN {
CUDNN_LINEAR_INPUT,
dataType,
dataType,
dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH,
mathType,
gsl::narrow_cast<int>(input_size),
gsl::narrow_cast<int>(hidden_size),
gsl::narrow_cast<int>(proj_size), // projected size