mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
[CUDA EP] RNN check if tf32 is allowed (#20338)
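Respect the use_tf32 flag: CudnnRNN::Set now takes a use_tf32 argument, and when TF32 is not allowed, float32 RNN descriptors are configured with CUDNN_FMA_MATH (no TF32 tensor cores) instead of CUDNN_DEFAULT_MATH. Half precision still uses CUDNN_TENSOR_OP_MATH.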
This commit is contained in:
parent
7ebc653f04
commit
5eae33fc6b
2 changed files with 13 additions and 6 deletions
|
|
@ -153,7 +153,8 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
|
|||
cudnn_direction_mode_,
|
||||
rnn_mode_,
|
||||
has_bias,
|
||||
CudnnTensor::GetDataType<CudaT>()));
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
UseTF32()));
|
||||
if (get_B) {
|
||||
ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B,
|
||||
w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
|
||||
|
|
@ -296,7 +297,8 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
cudnn_direction_mode_,
|
||||
rnn_mode_,
|
||||
has_bias,
|
||||
CudnnTensor::GetDataType<CudaT>()));
|
||||
CudnnTensor::GetDataType<CudaT>(),
|
||||
UseTF32()));
|
||||
|
||||
// Prepare the weight data
|
||||
size_t w_data_size_in_bytes = 0;
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
#include "core/common/gsl.h"
|
||||
|
||||
#include <cudnn.h>
|
||||
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
|
|
@ -40,10 +38,17 @@ class CudnnRNN {
|
|||
|
||||
Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers,
|
||||
cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
|
||||
cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) {
|
||||
cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType, bool use_tf32) {
|
||||
if (!cudnn_rnn_desc_)
|
||||
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));
|
||||
|
||||
cudnnMathType_t mathType = CUDNN_DEFAULT_MATH;
|
||||
if (dataType == CUDNN_DATA_HALF) {
|
||||
mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else if (dataType == CUDNN_DATA_FLOAT && !use_tf32) {
|
||||
mathType = CUDNN_FMA_MATH; // omit TF32 tensor cores
|
||||
}
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_,
|
||||
CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
|
||||
rnn_mode,
|
||||
|
|
@ -52,7 +57,7 @@ class CudnnRNN {
|
|||
CUDNN_LINEAR_INPUT,
|
||||
dataType,
|
||||
dataType,
|
||||
dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH,
|
||||
mathType,
|
||||
gsl::narrow_cast<int>(input_size),
|
||||
gsl::narrow_cast<int>(hidden_size),
|
||||
gsl::narrow_cast<int>(proj_size), // projected size
|
||||
|
|
|
|||
Loading…
Reference in a new issue