cudnnRNNForwardInferenceEx doesn't support zero-length sequences in the batch

Fix the issue that cudnnRNNForwardInferenceEx doesn't support zero-length sequences in the batch.

Solution:
Before calling cudnnRNNForwardInferenceEx, reset every zero sequence length in the batch to 1, and keep an array that tracks the batch ids whose sequence length was 0. Once the result is available, call a CUDA kernel that masks the output using the batch ids tracked in that array.
This commit is contained in:
Hector Li 2019-08-21 09:59:43 -07:00 committed by GitHub
parent d0d82432f3
commit e652a236b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 192 additions and 6 deletions

View file

@ -230,6 +230,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
size_t workspace_bytes;
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
int32_t zero_seq_count = 0;
std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
int64_t zero_seq_index_cache_size = 0;
if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
@ -252,10 +255,32 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
workspace_cuda.get(),
workspace_bytes));
} else {
// cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
// there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
for (int i = 0; i < batch_size; ++i) {
if (0 == seq_len_array[i]) {
seq_len_array[i] = 1;
zero_seq_index_cache[zero_seq_count] = i;
++zero_seq_count;
}
}
// Calculate the zero position cache for reverse direction if it's bidirectional
// The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
// we hacked the 0 sequence to 1
if (zero_seq_count && num_directions_ > 1) {
zero_seq_index_cache_size = zero_seq_count * num_directions_;
zero_seq_index_cache.resize(zero_seq_index_cache_size);
for (int i = 0; i < zero_seq_count; ++i) {
zero_seq_index_cache[zero_seq_count + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
}
}
CudnnDataTensor x_desc;
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, seq_len_array.data());
CudnnDataTensor y_desc;
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data());
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
rnn_desc,
@ -277,8 +302,13 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
nullptr, nullptr, nullptr, nullptr,
workspace_cuda.get(),
workspace_bytes));
// Terminate early in this case: Y data is not required and Y_h is already obtained correctly, so the following code that retrieves Y_h from Y data is not needed.
if (nullptr == Y) {
// Mask on output for 0 sequence batches
if (zero_seq_count > 0) {
SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data);
}
return Status::OK();
}
}
@ -312,10 +342,14 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
}
}
// Mask on output for 0 sequence batches
if (zero_seq_count > 0) {
SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data);
}
if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
auto count = sequence_lens->Shape().Size();
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), count);
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t));
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), batch_size);
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t));
sequence_lens_buffer.CopyToGpu();
RnnMaskImpl(gsl::narrow_cast<int32_t>(num_directions_),
gsl::narrow_cast<int32_t>(seq_length),
@ -330,6 +364,24 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}
// Masks (sets to 0) the output rows that belong to batch entries whose sequence length was 0.
//
// The first `zero_seq_index_cache_size` entries of `zero_seq_index_cache` are copied into a
// CudaAsyncBuffer, transferred with CopyToGpu(), and handed to MaskZeroSequences together with
// hidden_size_ and the three output pointers.
//
// Parameters:
//   zero_seq_index_cache_size  number of leading entries of zero_seq_index_cache to use
//   zero_seq_index_cache       row indices to clear; each index is scaled by hidden_size_ in the kernel
//   y_data, y_h_data, y_c_data output buffers; forwarded as-is (the kernel skips null pointers)
template <typename T>
void CudnnRnnBase<T>::SetZeroSequences(const int64_t zero_seq_index_cache_size,
                                       const std::vector<int32_t> zero_seq_index_cache,
                                       T* y_data,
                                       T* y_h_data,
                                       T* y_c_data) const {
  typedef typename ToCudaType<T>::MappedType CudaT;

  // Nothing to mask: skip the zero-length buffer and the kernel launch, which would otherwise be
  // issued with a grid of 0 blocks.
  // NOTE(review): ComputeInternal only assigns a non-zero size when num_directions_ > 1, so a
  // unidirectional run with zero-length sequences arrives here with size 0 and is not masked --
  // confirm whether the size should be zero_seq_count in that case.
  if (zero_seq_index_cache_size <= 0) {
    return;
  }

  CudaAsyncBuffer<int32_t> zero_seq_index_cache_async_buffer(this, GetDeviceId(), zero_seq_index_cache_size);
  memcpy(zero_seq_index_cache_async_buffer.CpuPtr(),
         zero_seq_index_cache.data(),
         zero_seq_index_cache_size * sizeof(int32_t));
  zero_seq_index_cache_async_buffer.CopyToGpu();

  // MaskZeroSequences takes the element count as size_t; the value is known to be positive here.
  MaskZeroSequences(gsl::narrow_cast<int32_t>(hidden_size_),
                    reinterpret_cast<CudaT*>(y_data),
                    reinterpret_cast<CudaT*>(y_h_data),
                    reinterpret_cast<CudaT*>(y_c_data),
                    zero_seq_index_cache_async_buffer.GpuPtr(),
                    static_cast<size_t>(zero_seq_index_cache_size));
}
template class CudnnRnnBase<float>;
template class CudnnRnnBase<double>;
template class CudnnRnnBase<MLFloat16>;

View file

@ -132,6 +132,12 @@ class CudnnRnnBase : public CudaKernel {
int& offset,
bool is_matrix) const;
// Sets to 0 the entries of y_data / y_h_data / y_c_data addressed by the first
// zero_seq_index_cache_size indices in zero_seq_index_cache (the batch entries whose
// sequence length was 0). The indices are uploaded to the GPU and applied by MaskZeroSequences.
void SetZeroSequences(const int64_t zero_seq_index_cache_size,
const std::vector<int32_t> zero_seq_index_cache,
T* y_data,
T* y_h_data,
T* y_c_data) const;
protected:
// W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor
std::vector<int> W_lin_layer_id_;

View file

@ -133,6 +133,48 @@ void RnnMaskImpl(const int32_t num_directions,
div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N);
}
// Kernel that clears one hidden_size-wide row per cached index.
//
// `id` comes from CALCULATE_ELEMENTWISE_INDEX_OR_EXIT (presumably the flat thread index, with
// threads at or beyond N exiting -- confirm against the macro definition). The thread reads
// zeor_seq_index_cache[id], scales it by hidden_size to get the row start, and writes 0 to that
// row in every output buffer whose pointer is not null.
template <typename T>
__global__ void _MaskZeroSequences(const int32_t hidden_size,
                                   T* y_output_data,
                                   T* y_h_output_data,
                                   T* y_c_output_data,
                                   const int32_t* zeor_seq_index_cache,
                                   const CUDA_LONG N) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

  // First element of the row this thread is responsible for.
  const int32_t row_begin = zeor_seq_index_cache[id] * hidden_size;

  // The null checks do not depend on the column, so evaluate them once.
  const bool clear_y = (y_output_data != nullptr);
  const bool clear_y_h = (y_h_output_data != nullptr);
  const bool clear_y_c = (y_c_output_data != nullptr);

  for (int col = 0; col < hidden_size; ++col) {
    const int32_t pos = row_begin + col;
    if (clear_y) {
      y_output_data[pos] = 0;
    }
    if (clear_y_h) {
      y_h_output_data[pos] = 0;
    }
    if (clear_y_c) {
      y_c_output_data[pos] = 0;
    }
  }
}
// Host-side launcher for the _MaskZeroSequences kernel.
//
// Parameters:
//   hidden_size           number of elements cleared per cached index
//   y_output_data,
//   y_h_output_data,
//   y_c_output_data       output buffers; a null pointer means "skip this buffer"
//   zeor_seq_index_cache  indices to clear; dereferenced inside the kernel
//   N                     number of entries in zeor_seq_index_cache
//
// The kernel is launched with GridDim::maxThreadsPerBlock threads per block and enough blocks
// to cover N entries.
template <typename T>
void MaskZeroSequences(const int32_t hidden_size,
                       T* y_output_data,
                       T* y_h_output_data,
                       T* y_c_output_data,
                       const int32_t* zeor_seq_index_cache,
                       const size_t N) {
  // With N == 0 the block count below would be 0, which is not a valid launch configuration;
  // there is also nothing to clear, so return without launching.
  if (N == 0) {
    return;
  }

  // Integer ceil-division: exact for every N, whereas rounding through float loses precision
  // once N no longer fits in a float mantissa.
  const size_t threads_per_block = static_cast<size_t>(GridDim::maxThreadsPerBlock);
  const int blocksPerGrid = static_cast<int>((N + threads_per_block - 1) / threads_per_block);

  _MaskZeroSequences<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
      hidden_size, y_output_data, y_h_output_data, y_c_output_data, zeor_seq_index_cache, (CUDA_LONG)N);
}
#define SPECIALIZED_RNN_IMPL(T) \
template void RnnMaskImpl<T>(const int32_t num_directions, \
const int32_t seq_length, \
@ -153,7 +195,13 @@ void RnnMaskImpl(const int32_t num_directions,
const int32_t hidden_size,\
const T* data, \
T* reordered_data, \
const size_t N);
const size_t N); \
template void MaskZeroSequences<T>(const int32_t hidden_size, \
T* y_output_data, \
T* y_h_output_data, \
T* y_c_output_data, \
const int32_t* zeor_seq_index_cache, \
const size_t N);
SPECIALIZED_RNN_IMPL(half)
SPECIALIZED_RNN_IMPL(float)

View file

@ -34,5 +34,12 @@ void RnnMaskImpl(const int32_t num_directions,
T* y_h_output_data,
const size_t N);
// Launches a kernel that writes 0 over hidden_size consecutive elements, starting at
// index * hidden_size, for each of the N indices in zeor_seq_index_cache_async_buffer.
// The write is applied to each of y_output_data, y_h_output_data and y_c_output_data that is
// not null. The index buffer is dereferenced inside the kernel.
template <typename T>
void MaskZeroSequences(const int32_t hidden_size,
T* y_output_data,
T* y_h_output_data,
T* y_c_output_data,
const int32_t* zeor_seq_index_cache_async_buffer,
const size_t N);
} // namespace cuda
} // namespace onnxruntime

View file

@ -823,6 +823,39 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) {
ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
}
// Bidirectional GRU over a batch of 4 where the second batch entry has sequence length 0.
// Every output value belonging to that entry is expected to be 0.
TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) {
  const int batch_size = 4;
  const int seq_length = 2;

  const std::string direction = "bidirectional";
  const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
  DeepCpuGruOpTestContext gru_ctx(direction, activations);

  // Per-entry sequence lengths; entry 1 is the zero-length one.
  std::vector<int> lengths = {2, 0, 2, 2};

  // 16 input values, two per line.
  std::vector<float> input = {
      -0.455351f, -0.276391f,
      0.855351f, 0.676391f,
      -0.185934f, -0.269585f,
      -0.585934f, 0.669585f,
      -0.351455f, -0.391276f,
      0.670351f, 0.894676f,
      0.987653f, 1.876567f,
      -1.234357f, -0.775668f};

  // All-zero initial hidden state (16 values).
  std::vector<float> h0(16, 0.0f);

  // Each 8-value row holds 4 batch entries x 2 values; entry 1 is zero in every row.
  std::vector<float> want_Y = {
      -0.0325528607f, 0.0774837881f, 0.0f, 0.0f, -0.0456649921f, 0.0462125241f, -0.1494070887f, 0.1356348693f,
      -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f,
      -0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f,
      -0.0190722197f, 0.0559314489f, 0.0f, 0.0f, -0.4121740460f, 0.0858790874f, 0.0524947792f, 0.1172080263f};

  std::vector<float> want_Y_h = {
      -0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f,
      -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f};

  gru_ctx.RunTest(input, batch_size, seq_length, lengths, &h0, want_Y, want_Y_h, true);
}
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) {
const std::string direction = "bidirectional";
const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};

View file

@ -1167,6 +1167,46 @@ TEST(LSTMTest, ONNXRuntime_TestLSTMShorterSeqInMiddle) {
context.RunTest(X_data, batch_size, seq_len, nullptr, nullptr, Y_data, Y_h_data, Y_c_data,
&sequence_length, use_bias, use_peepholes, 0.0f, false, false);
}
// Bidirectional LSTM over a batch of 4 with sequence lengths {2, 0, 1, 2}.
// Every output value belonging to the zero-length entry (entry 1) is expected to be 0.
TEST(LSTMTest, ONNXRuntime_TestLSTMZeroSeqInMiddle) {
  const int seq_len = 2;
  int batch_size = 4;

  bool use_bias = true;
  bool use_peepholes = false;

  std::string direction = "bidirectional";
  std::vector<std::string> activations = {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"};

  // Per-entry sequence lengths.
  std::vector<int> lengths = {2, 0, 1, 2};

  // 16 input values, two per line.
  std::vector<float> input = {
      -0.455351f, -0.776391f,
      0.0f, 0.0f,
      0.348763f, 0.678345f,
      0.877836f, 0.543859f,
      -0.185934f, -0.169585f,
      0.0f, 0.0f,
      0.078053f, 0.163457f,
      0.846098f, 0.987531f};

  // Each 8-value row holds 4 batch entries x 2 values; entry 1 is zero in every row.
  std::vector<float> want_Y = {
      0.02907280f, 0.01765226f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.12951779f, -0.00989562f,
      0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f,
      -0.02912546f, 0.04120104f, 0.0f, 0.0f, 0.0f, 0.0f, -0.22162350f, 0.03132058f,
      -0.04350187f, 0.03531464f, 0.0f, 0.0f, 0.0f, 0.0f, -0.17885581f, 0.01959856f};

  std::vector<float> want_Y_h = {
      -0.02912546f, 0.04120104f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.22162350f, 0.03132058f,
      0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f};

  std::vector<float> want_Y_c = {
      -0.06609819f, 0.06838701f, 0.0f, 0.0f, -0.2894889f, 0.07438067f, -0.39655977f, 0.05050645f,
      0.04934450f, 0.07126625f, 0.0f, 0.0f, -0.28948891f, 0.07438067f, -0.34931409f, 0.02799958f};

  LstmOpContext2x1x2x2 lstm_ctx(direction, activations);
  lstm_ctx.RunTest(input, batch_size, seq_len, nullptr, nullptr, want_Y, want_Y_h, want_Y_c,
                   &lengths, use_bias, use_peepholes, 0.0f, false, false);
}
#endif // USE_NGRAPH
} // namespace test