mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
cudnnRNNForwardInferenceEx doesn't support zero-length sequences in a batch
Fix the issue that cudnnRNNForwardInferenceEx doesn't support zero-length sequences in a batch. Solution: before calling cudnnRNNForwardInferenceEx, reset every zero sequence length in the batch to 1 and keep an array that tracks the batch ids whose sequence length was zero. Once the result is available, call a CUDA kernel that masks the output using the batch ids tracked in that array.
This commit is contained in:
parent
d0d82432f3
commit
e652a236b4
6 changed files with 192 additions and 6 deletions
|
|
@ -230,6 +230,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
size_t workspace_bytes;
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
|
||||
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
|
||||
int32_t zero_seq_count = 0;
|
||||
std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
|
||||
int64_t zero_seq_index_cache_size = 0;
|
||||
|
||||
if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
|
||||
|
|
@ -252,10 +255,32 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
workspace_cuda.get(),
|
||||
workspace_bytes));
|
||||
} else {
|
||||
// cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
|
||||
// there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
|
||||
std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
if (0 == seq_len_array[i]) {
|
||||
seq_len_array[i] = 1;
|
||||
zero_seq_index_cache[zero_seq_count] = i;
|
||||
++zero_seq_count;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate the zero position cache for reverse direction if it's bidirectional
|
||||
// The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
|
||||
// we hacked the 0 sequence to 1
|
||||
if (zero_seq_count && num_directions_ > 1) {
|
||||
zero_seq_index_cache_size = zero_seq_count * num_directions_;
|
||||
zero_seq_index_cache.resize(zero_seq_index_cache_size);
|
||||
for (int i = 0; i < zero_seq_count; ++i) {
|
||||
zero_seq_index_cache[zero_seq_count + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CudnnDataTensor x_desc;
|
||||
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
|
||||
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, seq_len_array.data());
|
||||
CudnnDataTensor y_desc;
|
||||
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
|
||||
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data());
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
|
||||
rnn_desc,
|
||||
|
|
@ -277,8 +302,13 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
nullptr, nullptr, nullptr, nullptr,
|
||||
workspace_cuda.get(),
|
||||
workspace_bytes));
|
||||
|
||||
// Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
|
||||
if (nullptr == Y) {
|
||||
// Mask on output for 0 sequence batches
|
||||
if (zero_seq_count > 0) {
|
||||
SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
@ -312,10 +342,14 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Mask on output for 0 sequence batches
|
||||
if (zero_seq_count > 0) {
|
||||
SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data);
|
||||
}
|
||||
|
||||
if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
|
||||
auto count = sequence_lens->Shape().Size();
|
||||
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), count);
|
||||
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t));
|
||||
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), batch_size);
|
||||
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t));
|
||||
sequence_lens_buffer.CopyToGpu();
|
||||
RnnMaskImpl(gsl::narrow_cast<int32_t>(num_directions_),
|
||||
gsl::narrow_cast<int32_t>(seq_length),
|
||||
|
|
@ -330,6 +364,24 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CudnnRnnBase<T>::SetZeroSequences(const int64_t zero_seq_index_cache_size,
|
||||
const std::vector<int32_t> zero_seq_index_cache,
|
||||
T* y_data,
|
||||
T* y_h_data,
|
||||
T* y_c_data) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
CudaAsyncBuffer<int32_t> zero_seq_index_cache_async_buffer(this, GetDeviceId(), zero_seq_index_cache_size);
|
||||
memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t));
|
||||
zero_seq_index_cache_async_buffer.CopyToGpu();
|
||||
MaskZeroSequences(gsl::narrow_cast<int32_t>(hidden_size_),
|
||||
reinterpret_cast<CudaT*>(y_data),
|
||||
reinterpret_cast<CudaT*>(y_h_data),
|
||||
reinterpret_cast<CudaT*>(y_c_data),
|
||||
zero_seq_index_cache_async_buffer.GpuPtr(),
|
||||
static_cast<int64_t>(zero_seq_index_cache_size));
|
||||
}
|
||||
|
||||
template class CudnnRnnBase<float>;
|
||||
template class CudnnRnnBase<double>;
|
||||
template class CudnnRnnBase<MLFloat16>;
|
||||
|
|
|
|||
|
|
@ -132,6 +132,12 @@ class CudnnRnnBase : public CudaKernel {
|
|||
int& offset,
|
||||
bool is_matrix) const;
|
||||
|
||||
void SetZeroSequences(const int64_t zero_seq_index_cache_size,
|
||||
const std::vector<int32_t> zero_seq_index_cache,
|
||||
T* y_data,
|
||||
T* y_h_data,
|
||||
T* y_c_data) const;
|
||||
|
||||
protected:
|
||||
// W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor
|
||||
std::vector<int> W_lin_layer_id_;
|
||||
|
|
|
|||
|
|
@ -133,6 +133,48 @@ void RnnMaskImpl(const int32_t num_directions,
|
|||
div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _MaskZeroSequences(const int32_t hidden_size,
|
||||
T* y_output_data,
|
||||
T* y_h_output_data,
|
||||
T* y_c_output_data,
|
||||
const int32_t* zeor_seq_index_cache,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
|
||||
int32_t zero_seq_offset = zeor_seq_index_cache[id] * hidden_size;
|
||||
|
||||
if (y_output_data != nullptr) {
|
||||
for (int i = 0; i < hidden_size; ++i) {
|
||||
y_output_data[zero_seq_offset + i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (y_h_output_data != nullptr) {
|
||||
for (int i = 0; i < hidden_size; ++i) {
|
||||
y_h_output_data[zero_seq_offset + i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (y_c_output_data != nullptr) {
|
||||
for (int i = 0; i < hidden_size; ++i) {
|
||||
y_c_output_data[zero_seq_offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MaskZeroSequences(const int32_t hidden_size,
|
||||
T* y_output_data,
|
||||
T* y_h_output_data,
|
||||
T* y_c_output_data,
|
||||
const int32_t* zeor_seq_index_cache,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
_MaskZeroSequences<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
hidden_size, y_output_data, y_h_output_data, y_c_output_data, zeor_seq_index_cache, (CUDA_LONG)N);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_RNN_IMPL(T) \
|
||||
template void RnnMaskImpl<T>(const int32_t num_directions, \
|
||||
const int32_t seq_length, \
|
||||
|
|
@ -153,7 +195,13 @@ void RnnMaskImpl(const int32_t num_directions,
|
|||
const int32_t hidden_size,\
|
||||
const T* data, \
|
||||
T* reordered_data, \
|
||||
const size_t N);
|
||||
const size_t N); \
|
||||
template void MaskZeroSequences<T>(const int32_t hidden_size, \
|
||||
T* y_output_data, \
|
||||
T* y_h_output_data, \
|
||||
T* y_c_output_data, \
|
||||
const int32_t* zeor_seq_index_cache, \
|
||||
const size_t N);
|
||||
|
||||
SPECIALIZED_RNN_IMPL(half)
|
||||
SPECIALIZED_RNN_IMPL(float)
|
||||
|
|
|
|||
|
|
@ -34,5 +34,12 @@ void RnnMaskImpl(const int32_t num_directions,
|
|||
T* y_h_output_data,
|
||||
const size_t N);
|
||||
|
||||
// Zeroes the rows selected by zero_seq_index_cache in each non-null output.
//
//   hidden_size           number of elements cleared per index
//   y_output_data,
//   y_h_output_data,
//   y_c_output_data       output buffers; pass nullptr for one that should be
//                         left untouched
//   zero_seq_index_cache  N row indices; each index is multiplied by hidden_size
//                         to get the first element to clear. Read from device
//                         code, so this must be a device pointer.
//   N                     number of entries in zero_seq_index_cache
template <typename T>
void MaskZeroSequences(const int32_t hidden_size,
                       T* y_output_data,
                       T* y_h_output_data,
                       T* y_c_output_data,
                       const int32_t* zero_seq_index_cache,
                       const size_t N);
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -823,6 +823,39 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) {
|
|||
ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
|
||||
}
|
||||
|
||||
// Bidirectional GRU where the second batch entry has sequence length 0.
// Every expected output value that belongs to that entry is 0.
TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) {
  const int batch_size = 4;
  const int seq_length = 2;

  // Per-batch sequence lengths; index 1 is the zero-length sequence under test.
  std::vector<int> sequence_length = {2, 0, 2, 2};

  // 16 input values (presumably seq_length x batch_size x input_size of 2 - confirm
  // against DeepCpuGruOpTestContext).
  std::vector<float> X = {-0.455351f, -0.276391f,
                          0.855351f, 0.676391f,
                          -0.185934f, -0.269585f,
                          -0.585934f, 0.669585f,
                          -0.351455f, -0.391276f,
                          0.670351f, 0.894676f,
                          0.987653f, 1.876567f,
                          -1.234357f, -0.775668f};

  // All-zero initial hidden state (16 values).
  std::vector<float> initial_h(16, 0.0f);

  std::vector<float> expected_Y = {
      -0.0325528607f, 0.0774837881f, 0.0f, 0.0f, -0.0456649921f, 0.0462125241f, -0.1494070887f, 0.1356348693f,
      -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f,

      -0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f,
      -0.0190722197f, 0.0559314489f, 0.0f, 0.0f, -0.4121740460f, 0.0858790874f, 0.0524947792f, 0.1172080263f};

  std::vector<float> expected_Y_h = {
      -0.0317345410f, 0.0898682102f, 0.0f, 0.0f, -0.4344840049f, 0.1124109625f, -0.0373909101f, 0.1958667039f,

      -0.0398676469f, 0.1030099019f, 0.0f, 0.0f, -0.2552363872f, 0.1258624643f, -0.1111927852f, 0.1987708956f};

  const std::string direction = "bidirectional";
  const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
  DeepCpuGruOpTestContext ctx(direction, activations);

  ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
}
|
||||
|
||||
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) {
|
||||
const std::string direction = "bidirectional";
|
||||
const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
|
||||
|
|
|
|||
|
|
@ -1167,6 +1167,46 @@ TEST(LSTMTest, ONNXRuntime_TestLSTMShorterSeqInMiddle) {
|
|||
context.RunTest(X_data, batch_size, seq_len, nullptr, nullptr, Y_data, Y_h_data, Y_c_data,
|
||||
&sequence_length, use_bias, use_peepholes, 0.0f, false, false);
|
||||
}
|
||||
|
||||
// Bidirectional LSTM where the second batch entry has sequence length 0 and the
// third has length 1. Every expected Y, Y_h and Y_c value that belongs to the
// zero-length entry is 0.
TEST(LSTMTest, ONNXRuntime_TestLSTMZeroSeqInMiddle) {
  const int seq_len = 2;
  int batch_size = 4;

  // Per-batch sequence lengths; index 1 is the zero-length sequence under test.
  std::vector<int> sequence_length = {2, 0, 1, 2};

  bool use_bias = true;
  bool use_peepholes = false;

  // 16 input values, written as two groups of four pairs.
  std::vector<float> X_data = {-0.455351f, -0.776391f,
                               0.0f, 0.0f,
                               0.348763f, 0.678345f,
                               0.877836f, 0.543859f,

                               -0.185934f, -0.169585f,
                               0.0f, 0.0f,
                               0.078053f, 0.163457f,
                               0.846098f, 0.987531f};

  std::vector<float> Y_data = {
      0.02907280f, 0.01765226f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.12951779f, -0.00989562f,
      0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f,

      -0.02912546f, 0.04120104f, 0.0f, 0.0f, 0.0f, 0.0f, -0.22162350f, 0.03132058f,
      -0.04350187f, 0.03531464f, 0.0f, 0.0f, 0.0f, 0.0f, -0.17885581f, 0.01959856f};

  std::vector<float> Y_h_data = {
      -0.02912546f, 0.04120104f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.22162350f, 0.03132058f,

      0.01841230f, 0.04093486f, 0.0f, 0.0f, -0.15355367f, 0.04701351f, -0.17956293f, 0.01607513f};

  std::vector<float> Y_c_data = {
      -0.06609819f, 0.06838701f, 0.0f, 0.0f, -0.2894889f, 0.07438067f, -0.39655977f, 0.05050645f,

      0.04934450f, 0.07126625f, 0.0f, 0.0f, -0.28948891f, 0.07438067f, -0.34931409f, 0.02799958f};

  std::string direction = "bidirectional";
  std::vector<std::string> activations = {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"};
  LstmOpContext2x1x2x2 context(direction, activations);

  context.RunTest(X_data, batch_size, seq_len, nullptr, nullptr, Y_data, Y_h_data, Y_c_data,
                  &sequence_length, use_bias, use_peepholes, 0.0f, false, false);
}
|
||||
#endif // USE_NGRAPH
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue