mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
remove const_cast which makes it's not thread safe. (#1463)
This commit is contained in:
parent
6be93f11e5
commit
31838fc9ee
1 changed files with 6 additions and 5 deletions
|
|
@ -207,19 +207,20 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
|
||||
IAllocatorUniquePtr<T> x_reversed_data;
|
||||
T* x_data = const_cast<T*>(X->template Data<T>());
|
||||
const T* x_data = X->template Data<T>();
|
||||
if (reverse_) {
|
||||
// reverse input data
|
||||
x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size);
|
||||
ReverseBySequence(gsl::narrow_cast<int32_t>(seq_length),
|
||||
gsl::narrow_cast<int32_t>(batch_size),
|
||||
gsl::narrow_cast<int32_t>(input_size),
|
||||
reinterpret_cast<CudaT*>(x_data),
|
||||
reinterpret_cast<const CudaT*>(x_data),
|
||||
reinterpret_cast<CudaT*>(x_reversed_data.get()),
|
||||
seq_length * batch_size * input_size);
|
||||
x_data = x_reversed_data.get();
|
||||
}
|
||||
|
||||
const T* x_data_input = reverse_ ? x_reversed_data.get() : x_data;
|
||||
|
||||
const T* hx_data = (initial_h == nullptr) ? nullptr : initial_h->template Data<T>();
|
||||
const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->template Data<T>();
|
||||
T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->template MutableData<T>();
|
||||
|
|
@ -248,7 +249,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
rnn_desc_,
|
||||
gsl::narrow_cast<int>(seq_length),
|
||||
x_desc.data(),
|
||||
x_data,
|
||||
x_data_input,
|
||||
hx_desc,
|
||||
hx_data,
|
||||
cx_desc,
|
||||
|
|
@ -272,7 +273,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
|
||||
rnn_desc_,
|
||||
x_desc,
|
||||
x_data,
|
||||
x_data_input,
|
||||
hx_desc,
|
||||
hx_data,
|
||||
cx_desc,
|
||||
|
|
|
|||
Loading…
Reference in a new issue