diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 4db30f934e..f6585775d9 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -207,19 +207,20 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } IAllocatorUniquePtr x_reversed_data; - T* x_data = const_cast(X->template Data()); + const T* x_data = X->template Data(); if (reverse_) { // reverse input data x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size); ReverseBySequence(gsl::narrow_cast(seq_length), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), - reinterpret_cast(x_data), + reinterpret_cast(x_data), reinterpret_cast(x_reversed_data.get()), seq_length * batch_size * input_size); - x_data = x_reversed_data.get(); } + const T* x_data_input = reverse_ ? x_reversed_data.get() : x_data; + const T* hx_data = (initial_h == nullptr) ? nullptr : initial_h->template Data(); const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->template Data(); T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->template MutableData(); @@ -248,7 +249,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { rnn_desc_, gsl::narrow_cast(seq_length), x_desc.data(), - x_data, + x_data_input, hx_desc, hx_data, cx_desc, @@ -272,7 +273,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(), rnn_desc_, x_desc, - x_data, + x_data_input, hx_desc, hx_data, cx_desc,