From e652a236b492e573bdaafe12b6fbe810de73aae8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 21 Aug 2019 09:59:43 -0700 Subject: [PATCH] cudnnRNNForwardInferenceEx doesn't support 0 sequence in the batches Fix the issue that cudnnRNNForwardInferenceEx doesn't support zero-length sequences in the batch. Solution: Reset each 0 sequence length to 1 before calling cudnnRNNForwardInferenceEx, and keep an array that tracks the batch ids which have a 0 sequence length. Once the result is obtained, call a CUDA kernel to mask the output using the batch ids tracked in the array. --- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 62 +++++++++++++++++-- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 6 ++ .../core/providers/cuda/rnn/rnn_impl.cu | 50 ++++++++++++++- .../core/providers/cuda/rnn/rnn_impl.h | 7 +++ .../providers/cpu/rnn/deep_cpu_gru_op_test.cc | 33 ++++++++++ .../cpu/rnn/deep_cpu_lstm_op_test.cc | 40 ++++++++++++ 6 files changed, 192 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 26d87bcdee..2b13aa5882 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -230,6 +230,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { size_t workspace_bytes; CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); auto workspace_cuda = GetScratchBuffer(workspace_bytes); + int32_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + int64_t zero_seq_index_cache_size = 0; if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(), @@ -252,10 +255,32 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { workspace_cuda.get(), workspace_bytes)); } else { + // cudnn doesn't support 0 sequence
inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); + for (int i = 0; i < batch_size; ++i) { + if (0 == seq_len_array[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache_size = zero_seq_count * num_directions_; + zero_seq_index_cache.resize(zero_seq_index_cache_size); + for (int i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[zero_seq_count + i] = static_cast(batch_size + zero_seq_index_cache[i]); + } + } + CudnnDataTensor x_desc; - x_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, sequence_lens_data); + x_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data()); CudnnDataTensor y_desc; - y_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data); + y_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()); CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(), rnn_desc, @@ -277,8 +302,13 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { nullptr, nullptr, nullptr, nullptr, workspace_cuda.get(), workspace_bytes)); + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. 
if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { + SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data); + } return Status::OK(); } } @@ -312,10 +342,14 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } } + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { + SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data); + } + if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - auto count = sequence_lens->Shape().Size(); - CudaAsyncBuffer sequence_lens_buffer(this, GetDeviceId(), count); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t)); + CudaAsyncBuffer sequence_lens_buffer(this, GetDeviceId(), batch_size); + memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); sequence_lens_buffer.CopyToGpu(); RnnMaskImpl(gsl::narrow_cast(num_directions_), gsl::narrow_cast(seq_length), @@ -330,6 +364,24 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +template +void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, + const std::vector zero_seq_index_cache, + T* y_data, + T* y_h_data, + T* y_c_data) const { + typedef typename ToCudaType::MappedType CudaT; + CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, GetDeviceId(), zero_seq_index_cache_size); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + zero_seq_index_cache_async_buffer.CopyToGpu(); + MaskZeroSequences(gsl::narrow_cast(hidden_size_), + reinterpret_cast(y_data), + reinterpret_cast(y_h_data), + reinterpret_cast(y_c_data), + zero_seq_index_cache_async_buffer.GpuPtr(), + static_cast(zero_seq_index_cache_size)); +} + template class CudnnRnnBase; template 
class CudnnRnnBase; template class CudnnRnnBase; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 3a4e234c19..6b7f4e9c14 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -132,6 +132,12 @@ class CudnnRnnBase : public CudaKernel { int& offset, bool is_matrix) const; + void SetZeroSequences(const int64_t zero_seq_index_cache_size, + const std::vector zero_seq_index_cache, + T* y_data, + T* y_h_data, + T* y_c_data) const; + protected: // W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor std::vector W_lin_layer_id_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index ae210ae681..930c3a4ddd 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -133,6 +133,48 @@ void RnnMaskImpl(const int32_t num_directions, div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); } +template +__global__ void _MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache, + const CUDA_LONG N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + int32_t zero_seq_offset = zeor_seq_index_cache[id] * hidden_size; + + if (y_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_output_data[zero_seq_offset + i] = 0; + } + } + + if (y_h_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_h_output_data[zero_seq_offset + i] = 0; + } + } + + if (y_c_output_data != nullptr) { + for (int i = 0; i < hidden_size; ++i) { + y_c_output_data[zero_seq_offset + i] = 0; + } + } +} + +template +void MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache, + const size_t N) { + int 
blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + _MaskZeroSequences<<>>( + hidden_size, y_output_data, y_h_output_data, y_c_output_data, zeor_seq_index_cache, (CUDA_LONG)N); +} + #define SPECIALIZED_RNN_IMPL(T) \ template void RnnMaskImpl(const int32_t num_directions, \ const int32_t seq_length, \ @@ -153,7 +195,13 @@ void RnnMaskImpl(const int32_t num_directions, const int32_t hidden_size,\ const T* data, \ T* reordered_data, \ - const size_t N); + const size_t N); \ +template void MaskZeroSequences(const int32_t hidden_size, \ + T* y_output_data, \ + T* y_h_output_data, \ + T* y_c_output_data, \ + const int32_t* zeor_seq_index_cache, \ + const size_t N); SPECIALIZED_RNN_IMPL(half) SPECIALIZED_RNN_IMPL(float) diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index d25d71aed3..78ceabf23b 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -34,5 +34,12 @@ void RnnMaskImpl(const int32_t num_directions, T* y_h_output_data, const size_t N); +template +void MaskZeroSequences(const int32_t hidden_size, + T* y_output_data, + T* y_h_output_data, + T* y_c_output_data, + const int32_t* zeor_seq_index_cache_async_buffer, + const size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc index fe9cf9a389..bf9e722f0e 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc @@ -823,6 +823,39 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) { ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true); } +TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) { + const std::string direction = "bidirectional"; + const std::vector activations = {"sigmoid", "tanh", "sigmoid", 
"tanh"}; + + DeepCpuGruOpTestContext ctx(direction, activations); + + const int batch_size = 4; + const int seq_length = 2; + std::vector X = {-0.455351f, -0.276391f, + 0.855351f, 0.676391f, + -0.185934f, -0.269585f, + -0.585934f, 0.669585f, + -0.351455f, -0.391276f, + 0.670351f, 0.894676f, + 0.987653f, 1.876567f, + -1.234357f, -0.775668f}; + std::vector sequence_length = {2, 0, 2, 2}; + std::vector initial_h = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + + std::vector expected_Y = {-0.0325528607f, 0.0774837881f, 0.0f, 0.0f, -0.0456649921f, 0.0462125241f, -0.1494070887f, 0.1356348693f, + -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f, + + -0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f, + -0.0190722197f, 0.0559314489f, 0.0f, 0.0f, -0.4121740460f, 0.0858790874f, 0.0524947792f, 0.1172080263f}; + + std::vector expected_Y_h = {-0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f, + + -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f}; + + ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true); +} + TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { const std::string direction = "bidirectional"; const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc index 34121981a7..b3661a9dbe 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc @@ -1167,6 +1167,46 @@ TEST(LSTMTest, ONNXRuntime_TestLSTMShorterSeqInMiddle) { context.RunTest(X_data, batch_size, seq_len, nullptr, nullptr, Y_data, Y_h_data, Y_c_data, 
&sequence_length, use_bias, use_peepholes, 0.0f, false, false); } + +TEST(LSTMTest, ONNXRuntime_TestLSTMZeroSeqInMiddle) { + const int seq_len = 2; + int batch_size = 4; + std::vector activations = {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}; + + bool use_bias = true; + bool use_peepholes = false; + + std::vector X_data = {-0.455351f, -0.776391f, + 0.0f, 0.0f, + 0.348763f, 0.678345f, + 0.877836f, 0.543859f, + + -0.185934f, -0.169585f, + 0.0f, 0.0f, + 0.078053f, 0.163457f, + 0.846098f, 0.987531f}; + + std::vector sequence_length = {2, 0, 1, 2}; + + std::vector Y_data = {0.02907280f, 0.01765226f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.12951779f, -0.00989562f, + 0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f, + + -0.02912546f, 0.04120104f, 0.0f, 0.0f, 0.0f, 0.0f, -0.22162350f, 0.03132058f, + -0.04350187f, 0.03531464f, 0.0f, 0.0f, 0.0f, 0.0f, -0.17885581f, 0.01959856f}; + + std::vector Y_h_data = {-0.02912546f, 0.04120104f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.22162350f, 0.03132058f, + + 0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f}; + + std::vector Y_c_data = {-0.06609819f, 0.06838701f, 0.0f, 0.0f, -0.2894889f, 0.07438067f, -0.39655977f, 0.05050645f, + + 0.04934450f, 0.07126625f, 0.0f, 0.0f, -0.28948891f, 0.07438067f, -0.34931409f, 0.02799958f}; + + std::string direction = "bidirectional"; + LstmOpContext2x1x2x2 context(direction, activations); + context.RunTest(X_data, batch_size, seq_len, nullptr, nullptr, Y_data, Y_h_data, Y_c_data, + &sequence_length, use_bias, use_peepholes, 0.0f, false, false); +} #endif // USE_NGRAPH } // namespace test