remove const_cast which makes it's not thread safe. (#1463)

This commit is contained in:
Hector Li 2019-07-22 17:55:29 -07:00 committed by GitHub
parent 6be93f11e5
commit 31838fc9ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -207,19 +207,20 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
}
IAllocatorUniquePtr<T> x_reversed_data;
T* x_data = const_cast<T*>(X->template Data<T>());
const T* x_data = X->template Data<T>();
if (reverse_) {
// reverse input data
x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size);
ReverseBySequence(gsl::narrow_cast<int32_t>(seq_length),
gsl::narrow_cast<int32_t>(batch_size),
gsl::narrow_cast<int32_t>(input_size),
reinterpret_cast<CudaT*>(x_data),
reinterpret_cast<const CudaT*>(x_data),
reinterpret_cast<CudaT*>(x_reversed_data.get()),
seq_length * batch_size * input_size);
x_data = x_reversed_data.get();
}
const T* x_data_input = reverse_ ? x_reversed_data.get() : x_data;
const T* hx_data = (initial_h == nullptr) ? nullptr : initial_h->template Data<T>();
const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->template Data<T>();
T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->template MutableData<T>();
@ -248,7 +249,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
rnn_desc_,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data,
x_data_input,
hx_desc,
hx_data,
cx_desc,
@ -272,7 +273,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
rnn_desc_,
x_desc,
x_data,
x_data_input,
hx_desc,
hx_data,
cx_desc,