From bbdd1d658b7e83a5c56dcd80980312125f37e731 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 12 Jun 2019 09:51:15 -0700 Subject: [PATCH] use cudnnRNNForwardInferenceEx for unpacked (padded) layout case (#899) 1. Use cudnnRNNForwardInferenceEx for unpacked (padded) layout case, place the sequence_lens data on CPU 2. Fix hard code device ID issue. In cuda kernel, it should get the device id from provider. --- onnxruntime/core/providers/cuda/cuda_common.h | 2 + .../providers/cuda/cuda_execution_provider.h | 2 + .../core/providers/cuda/cudnn_common.cc | 35 +++++++ .../core/providers/cuda/cudnn_common.h | 19 ++++ .../core/providers/cuda/math/matmul.cc | 2 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 95 +++++++++++++------ .../core/providers/cuda/rnn/cudnn_rnn_base.h | 19 ++-- onnxruntime/core/providers/cuda/rnn/gru.cc | 21 ++-- onnxruntime/core/providers/cuda/rnn/lstm.cc | 21 ++-- onnxruntime/core/providers/cuda/rnn/rnn.cc | 21 ++-- onnxruntime/core/providers/cuda/tensor/pad.cc | 2 +- .../core/providers/cuda/tensor/slice.cc | 2 +- .../core/providers/cuda/tensor/split.cc | 2 +- .../core/providers/cuda/tensor/tile.cc | 2 +- .../core/providers/cuda/tensor/transpose.cc | 2 +- .../core/providers/cuda/tensor/upsample.cc | 2 +- .../providers/cpu/rnn/deep_cpu_gru_op_test.cc | 81 ++++++++-------- 17 files changed, 212 insertions(+), 118 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index d28c81ea9a..9ce0ffbfb9 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -140,6 +140,8 @@ class CudaKernel : public OpKernel { return provider_->CopyTensor(src, dst); } + inline int GetDeviceId() const { return provider_->GetDeviceId(); } + private: CUDAExecutionProvider* provider_; }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 847df0cc2f..b42ffc8854 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -76,6 +76,8 @@ class CUDAExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph, const std::vector& kernel_registries) const override; + int GetDeviceId() const { return device_id_; } + private: cudaStream_t streams_[kTotalCudaStreams]; int device_id_; diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index ff3a09e73a..910f828c01 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -59,6 +59,41 @@ cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN engine currently supports only single/double/half precision data types."); } +CudnnDataTensor::CudnnDataTensor() + : tensor_(nullptr) { +} + +CudnnDataTensor::~CudnnDataTensor() { + if (tensor_ != nullptr) { + cudnnDestroyRNNDataDescriptor(tensor_); + tensor_ = nullptr; + } +} + +Status CudnnDataTensor::CreateTensorIfNeeded() { + if (!tensor_) + CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDataDescriptor(&tensor_)); + return Status::OK(); +} + +Status CudnnDataTensor::Set(cudnnDataType_t dataType, + int64_t max_seq_length, + int64_t batch_size, + int64_t data_size, + const int32_t* seq_lengths) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + + cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED; + float padding_fill = 0.0f; + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout, + static_cast(max_seq_length), + static_cast(batch_size), + static_cast(data_size), + seq_lengths, + static_cast(&padding_fill))); + return Status::OK(); +} + CudnnFilterDescriptor::CudnnFilterDescriptor() : desc_(nullptr) { cudnnCreateFilterDescriptor(&desc_); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 3b56b9c68a..158847e8d8 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -28,6 +28,25 @@ class CudnnTensor final { cudnnTensorDescriptor_t tensor_; }; +class CudnnDataTensor final { + public: + CudnnDataTensor(); + ~CudnnDataTensor(); + + Status Set(cudnnDataType_t dataType, + int64_t max_seq_length, + int64_t batch_size, + int64_t data_size, + const int32_t* seq_lengths); + + operator cudnnRNNDataDescriptor_t() const { return tensor_; } + + private: + Status CreateTensorIfNeeded(); + + cudnnRNNDataDescriptor_t tensor_; +}; + class CudnnFilterDescriptor final { public: CudnnFilterDescriptor(); diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index ad49d793be..c88d0f1613 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -67,7 +67,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { static_cast(helper.N()))); return Status::OK(); } - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer left_arrays(this, device_id, helper.LeftOffsets().size()); CudaAsyncBuffer right_arrays(this, device_id, helper.RightOffsets().size()); CudaAsyncBuffer output_arrays(this, device_id, helper.OutputOffsets().size()); diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index e56f59df16..7645fb763f 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -133,9 +133,9 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { const Tensor* W; const Tensor* R; const Tensor* B; - bool get_W = info.TryGetConstantInput(Input_Index::W, &W); - bool get_R = info.TryGetConstantInput(Input_Index::R, &R); - bool get_B = info.TryGetConstantInput(Input_Index::B, &B); + bool get_W = info.TryGetConstantInput(RNN_Input_Index::W, &W); + bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); + bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); if (get_W && get_R) { if (get_B) { @@ -154,15 +154,15 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { typedef typename ToCudaType::MappedType CudaT; // inputs - const Tensor* X = ctx->Input(Input_Index::X); // inputs. [seq_length, batch_size, input_size] + const Tensor* X = ctx->Input(RNN_Input_Index::X); // inputs. [seq_length, batch_size, input_size] ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] } int64_t seq_length = X->Shape()[0]; @@ -200,9 +200,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { - const Tensor& W = *ctx->Input(Input_Index::W); - const Tensor& R = *ctx->Input(Input_Index::R); - const Tensor* B = ctx->Input(Input_Index::B); + const Tensor& W = *ctx->Input(RNN_Input_Index::W); + const Tensor& R = *ctx->Input(RNN_Input_Index::R); + const Tensor* B = ctx->Input(RNN_Input_Index::B); ReorganizeWeights(&W, &R, B, w_data, w_desc); } @@ -240,25 +240,54 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { workspace_bytes *= num_directions_; auto workspace_cuda = GetScratchBuffer(workspace_bytes); - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(), - rnn_desc_, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); + if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { + CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(), + rnn_desc_, + gsl::narrow_cast(seq_length), + x_desc.data(), + x_data, + hx_desc, + hx_data, + cx_desc, + cx_data, + weight_cached_ ? w_desc_cache_ : w_desc, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + y_desc.data(), + y_data, + y_h_desc, + y_h_data, + y_c_desc, + y_c_data, + workspace_cuda.get(), + workspace_bytes)); + } else { + CudnnDataTensor x_desc; + x_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, sequence_lens_data); + CudnnDataTensor y_desc; + y_desc.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(), + rnn_desc_, + x_desc, + x_data, + hx_desc, + hx_data, + cx_desc, + cx_data, + weight_cached_ ? w_desc_cache_ : w_desc, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + y_desc, + y_data, + y_h_desc, + y_h_data, + y_c_desc, + y_c_data, + nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + workspace_cuda.get(), + workspace_bytes)); + } + IAllocatorUniquePtr y_reorganized_data; if (reverse_ || num_directions_ == 2) { //reverse output @@ -288,12 +317,16 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } } - if (sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { + if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { + auto count = sequence_lens->Shape().Size(); + CudaAsyncBuffer sequence_lens_buffer(this, GetDeviceId(), count); + memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t)); + sequence_lens_buffer.CopyToGpu(); RnnMaskImpl(gsl::narrow_cast(num_directions_), gsl::narrow_cast(seq_length), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), - sequence_lens_data, + sequence_lens_buffer.GpuPtr(), reinterpret_cast(y_data), reinterpret_cast(y_h_data), output_size); diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 4d44c96880..1c08d5af2b 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -11,6 +11,16 @@ namespace onnxruntime { namespace cuda { +enum RNN_Input_Index { + X = 0, + W = 1, + R = 2, + B = 3, + sequence_lens = 4, + initial_h = 5, + initial_c = 6 +}; + class CudnnDropout { public: CudnnDropout() : dropout_desc_(nullptr) { @@ -158,15 +168,6 @@ class CudnnRnnBase : public CudaKernel { IAllocatorUniquePtr w_data_cache_; bool weight_cached_; - enum Input_Index { - X = 0, - W = 1, - R = 2, - B = 3, - sequence_lens = 4, - initial_h = 5, - initial_c = 6 - }; enum Output_Index { Y = 0, Y_h = 1, diff --git a/onnxruntime/core/providers/cuda/rnn/gru.cc b/onnxruntime/core/providers/cuda/rnn/gru.cc index 38f311008c..c2ad3956a1 100644 --- a/onnxruntime/core/providers/cuda/rnn/gru.cc +++ b/onnxruntime/core/providers/cuda/rnn/gru.cc @@ -10,16 +10,17 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GRU, \ - kOnnxDomain, \ - 7, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GRU, \ + kOnnxDomain, \ + 7, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(RNN_Input_Index::sequence_lens), \ GRU); REGISTER_KERNEL_TYPED(float); diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.cc b/onnxruntime/core/providers/cuda/rnn/lstm.cc index e78ddf534f..e8100dfc0b 100644 --- a/onnxruntime/core/providers/cuda/rnn/lstm.cc +++ b/onnxruntime/core/providers/cuda/rnn/lstm.cc @@ -8,16 +8,17 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - LSTM, \ - kOnnxDomain, \ - 7, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + LSTM, \ + kOnnxDomain, \ + 7, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(RNN_Input_Index::sequence_lens), \ LSTM); REGISTER_KERNEL_TYPED(float); diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 277b5a88b6..3d0b3e49cd 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -10,16 +10,17 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - RNN, \ - kOnnxDomain, \ - 7, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RNN, \ + kOnnxDomain, \ + 7, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(RNN_Input_Index::sequence_lens), \ RNN); REGISTER_KERNEL_TYPED(float); diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index f57779706c..e5f90dab62 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -24,7 +24,7 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { const auto& input_tensor = *ctx->Input(0); auto const& input_shape = input_tensor.Shape(); auto dimension_count = input_shape.NumDimensions(); - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer input_dims(this, device_id, input_shape.GetDims()); CudaAsyncBuffer input_strides(this, device_id, dimension_count); CudaAsyncBuffer lower_pads(this, device_id, dimension_count); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index 53f9293ccd..83379f89af 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -85,7 +85,7 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { if (output_size == 0) { return Status::OK(); } - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer starts_buffer(this, device_id, dimension_count); gsl::span starts_buffer_span = starts_buffer.CpuSpan(); for (int i = 0; i < dimension_count; ++i) { diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index 6b222d6e62..f6dd317c5b 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -41,7 +41,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const { auto& input_dims = input_shape.GetDims(); std::vector output_dimensions{input_dims}; - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer output_ptr(this, device_id, num_outputs); gsl::span output_ptr_span = output_ptr.CpuSpan(); for (int i = 0; i < num_outputs; ++i) { diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index 263a789bce..390d9139de 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -42,7 +42,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { T* output_data = output_tensor.template MutableData(); const T* input_data = input_tensor.template Data(); - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer input_strides(this, device_id, rank); CudaAsyncBuffer fdm_input_shape(this, device_id, rank); CudaAsyncBuffer fdm_output_strides(this, device_id, rank); diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 04e69bbc8c..53f6d32715 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -92,7 +92,7 @@ Status Transpose::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } - int device_id = 0; + int device_id = GetDeviceId(); CudaAsyncBuffer input_strides(this, device_id, rank); CudaAsyncBuffer perm(this, device_id, *p_perm); CudaAsyncBuffer fdm_output_strides(this, device_id, rank); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 1f5797dec8..88248983d7 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -51,7 +51,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector::MappedType CudaT; // kernel - int device_id = 0; + int device_id = GetDeviceId(); TensorPitches input_pitches(X_dims); CudaAsyncBuffer input_strides(this, device_id, rank); gsl::span input_stride_span = input_strides.CpuSpan(); diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc index e0010423b1..3314a8e627 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc @@ -474,19 +474,19 @@ void DeepCpuGruOpTestContext::RunTest(const std::vector& X, alphas_, betas_); - //::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_, - // expected_Y, expected_Y_h, - // input_size_, batch_size, hidden_dim_, seq_length, - // use_bias_ ? &gru_bias_ : nullptr, - // initial_h, - // &sequence_lens, - // direction_, - // 9999999999.f, - // /*output_sequence*/ false, - // linear_before_reset, - // activation_func_names_, - // alphas_, - // betas_); + ::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_, + expected_Y, expected_Y_h, + input_size_, batch_size, hidden_dim_, seq_length, + use_bias_ ? &gru_bias_ : nullptr, + initial_h, + &sequence_lens, + direction_, + 9999999999.f, + /*output_sequence*/ false, + linear_before_reset, + activation_func_names_, + alphas_, + betas_); } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) { @@ -762,34 +762,33 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true); } -// Need CPU fix -//TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) { -// const std::string direction = "bidirectional"; -// const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; -// -// DeepCpuGruOpTestContext ctx(direction, activations); -// -// const int batch_size = 2; -// const int seq_length = 2; -// std::vector X = {-0.455351f, -0.276391f, -// 0.855351f, 0.676391f, -// -0.185934f, -0.269585f, -// 0.585934f, 0.669585f}; -// std::vector sequence_length = {2, 1}; -// std::vector initial_h = {0.0f, 0.0f, 0.0f, 0.0f, -// 0.0f, 0.0f, 0.0f, 0.0f}; -// std::vector expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f, -// -0.0559310019f, 0.101836264f, -0.385578573f, 0.0370728001f, -// -// -0.0577347837f, 0.0796165839f, 0.0f, 0.0f, -// -0.0456649922f, 0.0462125242f, 0.0f, 0.0f}; -// std::vector expected_Y_h = {-0.0577347837f, 0.0796165839f, -// -0.275918573f, -0.00228558504f, -// -0.0559310019f, 0.101836264f, -// -0.385578573f, 0.0370728001f}; -// -// ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true); -//} +TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) { + const std::string direction = "bidirectional"; + const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + + DeepCpuGruOpTestContext ctx(direction, activations); + + const int batch_size = 2; + const int seq_length = 2; + std::vector X = {-0.455351f, -0.276391f, + 0.855351f, 0.676391f, + -0.185934f, -0.269585f, + 0.585934f, 0.669585f}; + std::vector sequence_length = {2, 1}; + std::vector initial_h = {0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f}; + std::vector expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f, + -0.0559310019f, 0.101836264f, -0.275918573f, -0.00228558504f, + + -0.0577347837f, 0.0796165839f, 0.0f, 0.0f, + -0.0456649922f, 0.0462125242f, 0.0f, 0.0f}; + std::vector expected_Y_h = {-0.0577347837f, 0.0796165839f, + -0.275918573f, -0.00228558504f, + -0.0559310019f, 0.101836264f, + -0.275918573f, -0.00228558504f}; + + ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true); +} TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { const std::string direction = "bidirectional";