mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
use cudnnRNNForwardInferenceEx for unpacked (padded) layout case (#899)
1. Use cudnnRNNForwardInferenceEx for the unpacked (padded) layout case, and place the sequence_lens data on the CPU. 2. Fix the hard-coded device ID issue: a CUDA kernel should get the device ID from its provider.
This commit is contained in:
parent
38963d81eb
commit
bbdd1d658b
17 changed files with 212 additions and 118 deletions
|
|
@ -140,6 +140,8 @@ class CudaKernel : public OpKernel {
|
|||
return provider_->CopyTensor(src, dst);
|
||||
}
|
||||
|
||||
// Returns the device id reported by this kernel's execution provider (provider_).
inline int GetDeviceId() const { return provider_->GetDeviceId(); }
|
||||
|
||||
private:
|
||||
CUDAExecutionProvider* provider_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -76,6 +76,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& kernel_registries) const override;
|
||||
|
||||
// Returns the device id stored in device_id_.
int GetDeviceId() const { return device_id_; }
|
||||
|
||||
private:
|
||||
cudaStream_t streams_[kTotalCudaStreams];
|
||||
int device_id_;
|
||||
|
|
|
|||
|
|
@ -59,6 +59,41 @@ cudnnDataType_t CudnnTensor::GetDataType() {
|
|||
ORT_THROW("cuDNN engine currently supports only single/double/half precision data types.");
|
||||
}
|
||||
|
||||
// Starts with no descriptor; the handle is created lazily by CreateTensorIfNeeded().
CudnnDataTensor::CudnnDataTensor() : tensor_(nullptr) {}
|
||||
|
||||
// Destroys the RNN data descriptor if one was ever created.
CudnnDataTensor::~CudnnDataTensor() {
  // Nothing to release when the descriptor was never created.
  if (tensor_ == nullptr) {
    return;
  }

  cudnnDestroyRNNDataDescriptor(tensor_);
  tensor_ = nullptr;
}
|
||||
|
||||
// Creates the cuDNN RNN data descriptor on first use; later calls reuse the existing handle.
// Returns an error Status if cudnnCreateRNNDataDescriptor fails, Status::OK() otherwise.
Status CudnnDataTensor::CreateTensorIfNeeded() {
  if (tensor_ == nullptr) {
    CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDataDescriptor(&tensor_));
  }
  return Status::OK();
}
|
||||
|
||||
// Configures the RNN data descriptor for a sequence-major, padded (unpacked) buffer.
//
// dataType        cuDNN element type of the described buffer.
// max_seq_length  length every sequence is padded to (narrowed to int for cuDNN).
// batch_size      number of sequences in the batch (narrowed to int for cuDNN).
// data_size       per-step vector size (narrowed to int for cuDNN).
// seq_lengths     one valid length per batch entry, passed through to
//                 cudnnSetRNNDataDescriptor; the kernels register sequence_lens as a
//                 CPU input, so this is expected to be a host pointer.
//
// Returns an error Status if the descriptor cannot be created or configured.
Status CudnnDataTensor::Set(cudnnDataType_t dataType,
                            int64_t max_seq_length,
                            int64_t batch_size,
                            int64_t data_size,
                            const int32_t* seq_lengths) {
  ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());

  // This descriptor is used with cudnnRNNForwardInferenceEx on buffers where every
  // sequence occupies max_seq_length steps, so the layout must be the unpacked
  // (padded) sequence-major one. The packed layout describes a buffer with the
  // padding removed, and padding_fill has no effect for it.
  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;

  // Value cuDNN writes into the padded positions of an output buffer.
  // NOTE(review): this is always a float; confirm that is acceptable when dataType
  // is half or double.
  float padding_fill = 0.0f;
  CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout,
                                                  static_cast<int>(max_seq_length),
                                                  static_cast<int>(batch_size),
                                                  static_cast<int>(data_size),
                                                  seq_lengths,
                                                  static_cast<void*>(&padding_fill)));
  return Status::OK();
}
|
||||
|
||||
// Creates the cuDNN filter descriptor at construction time.
CudnnFilterDescriptor::CudnnFilterDescriptor()
    : desc_(nullptr) {
  // The status returned by cuDNN is not checked here; desc_ keeps its nullptr
  // initial value unless the call stores a handle into it.
  cudnnCreateFilterDescriptor(&desc_);
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,25 @@ class CudnnTensor final {
|
|||
cudnnTensorDescriptor_t tensor_;
|
||||
};
|
||||
|
||||
class CudnnDataTensor final {
|
||||
public:
|
||||
CudnnDataTensor();
|
||||
~CudnnDataTensor();
|
||||
|
||||
Status Set(cudnnDataType_t dataType,
|
||||
int64_t max_seq_length,
|
||||
int64_t batch_size,
|
||||
int64_t data_size,
|
||||
const int32_t* seq_lengths);
|
||||
|
||||
operator cudnnRNNDataDescriptor_t() const { return tensor_; }
|
||||
|
||||
private:
|
||||
Status CreateTensorIfNeeded();
|
||||
|
||||
cudnnRNNDataDescriptor_t tensor_;
|
||||
};
|
||||
|
||||
class CudnnFilterDescriptor final {
|
||||
public:
|
||||
CudnnFilterDescriptor();
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
static_cast<int>(helper.N())));
|
||||
return Status::OK();
|
||||
}
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<const CudaT*> left_arrays(this, device_id, helper.LeftOffsets().size());
|
||||
CudaAsyncBuffer<const CudaT*> right_arrays(this, device_id, helper.RightOffsets().size());
|
||||
CudaAsyncBuffer<CudaT*> output_arrays(this, device_id, helper.OutputOffsets().size());
|
||||
|
|
|
|||
|
|
@ -133,9 +133,9 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
|
|||
const Tensor* W;
|
||||
const Tensor* R;
|
||||
const Tensor* B;
|
||||
bool get_W = info.TryGetConstantInput(Input_Index::W, &W);
|
||||
bool get_R = info.TryGetConstantInput(Input_Index::R, &R);
|
||||
bool get_B = info.TryGetConstantInput(Input_Index::B, &B);
|
||||
bool get_W = info.TryGetConstantInput(RNN_Input_Index::W, &W);
|
||||
bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R);
|
||||
bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);
|
||||
|
||||
if (get_W && get_R) {
|
||||
if (get_B) {
|
||||
|
|
@ -154,15 +154,15 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
// inputs
|
||||
const Tensor* X = ctx->Input<Tensor>(Input_Index::X); // inputs. [seq_length, batch_size, input_size]
|
||||
const Tensor* X = ctx->Input<Tensor>(RNN_Input_Index::X); // inputs. [seq_length, batch_size, input_size]
|
||||
ORT_ENFORCE(nullptr != X);
|
||||
|
||||
// optional inputs
|
||||
const Tensor* sequence_lens = ctx->Input<Tensor>(Input_Index::sequence_lens); // [batch_size]
|
||||
const Tensor* initial_h = ctx->Input<Tensor>(Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
|
||||
const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens); // [batch_size]
|
||||
const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
|
||||
const Tensor* initial_c(nullptr);
|
||||
if (rnn_mode_ == CUDNN_LSTM) {
|
||||
initial_c = ctx->Input<Tensor>(Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
|
||||
initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
|
||||
}
|
||||
|
||||
int64_t seq_length = X->Shape()[0];
|
||||
|
|
@ -200,9 +200,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
IAllocatorUniquePtr<void> w_data;
|
||||
CudnnFilterDescriptor w_desc;
|
||||
if (!weight_cached_) {
|
||||
const Tensor& W = *ctx->Input<Tensor>(Input_Index::W);
|
||||
const Tensor& R = *ctx->Input<Tensor>(Input_Index::R);
|
||||
const Tensor* B = ctx->Input<Tensor>(Input_Index::B);
|
||||
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
|
||||
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
|
||||
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
|
||||
ReorganizeWeights(&W, &R, B, w_data, w_desc);
|
||||
}
|
||||
|
||||
|
|
@ -240,25 +240,54 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
workspace_bytes *= num_directions_;
|
||||
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
|
||||
rnn_desc_,
|
||||
gsl::narrow_cast<int>(seq_length),
|
||||
x_desc.data(),
|
||||
x_data,
|
||||
hx_desc,
|
||||
hx_data,
|
||||
cx_desc,
|
||||
cx_data,
|
||||
weight_cached_ ? w_desc_cache_ : w_desc,
|
||||
weight_cached_ ? w_data_cache_.get() : w_data.get(),
|
||||
y_desc.data(),
|
||||
y_data,
|
||||
y_h_desc,
|
||||
y_h_data,
|
||||
y_c_desc,
|
||||
y_c_data,
|
||||
workspace_cuda.get(),
|
||||
workspace_bytes));
|
||||
if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
|
||||
rnn_desc_,
|
||||
gsl::narrow_cast<int>(seq_length),
|
||||
x_desc.data(),
|
||||
x_data,
|
||||
hx_desc,
|
||||
hx_data,
|
||||
cx_desc,
|
||||
cx_data,
|
||||
weight_cached_ ? w_desc_cache_ : w_desc,
|
||||
weight_cached_ ? w_data_cache_.get() : w_data.get(),
|
||||
y_desc.data(),
|
||||
y_data,
|
||||
y_h_desc,
|
||||
y_h_data,
|
||||
y_c_desc,
|
||||
y_c_data,
|
||||
workspace_cuda.get(),
|
||||
workspace_bytes));
|
||||
} else {
|
||||
CudnnDataTensor x_desc;
|
||||
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
|
||||
CudnnDataTensor y_desc;
|
||||
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
|
||||
|
||||
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
|
||||
rnn_desc_,
|
||||
x_desc,
|
||||
x_data,
|
||||
hx_desc,
|
||||
hx_data,
|
||||
cx_desc,
|
||||
cx_data,
|
||||
weight_cached_ ? w_desc_cache_ : w_desc,
|
||||
weight_cached_ ? w_data_cache_.get() : w_data.get(),
|
||||
y_desc,
|
||||
y_data,
|
||||
y_h_desc,
|
||||
y_h_data,
|
||||
y_c_desc,
|
||||
y_c_data,
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
workspace_cuda.get(),
|
||||
workspace_bytes));
|
||||
}
|
||||
|
||||
IAllocatorUniquePtr<T> y_reorganized_data;
|
||||
if (reverse_ || num_directions_ == 2) {
|
||||
//reverse output
|
||||
|
|
@ -288,12 +317,16 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
if (sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
|
||||
if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
|
||||
auto count = sequence_lens->Shape().Size();
|
||||
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), count);
|
||||
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t));
|
||||
sequence_lens_buffer.CopyToGpu();
|
||||
RnnMaskImpl(gsl::narrow_cast<int32_t>(num_directions_),
|
||||
gsl::narrow_cast<int32_t>(seq_length),
|
||||
gsl::narrow_cast<int32_t>(batch_size),
|
||||
gsl::narrow_cast<int32_t>(hidden_size_),
|
||||
sequence_lens_data,
|
||||
sequence_lens_buffer.GpuPtr(),
|
||||
reinterpret_cast<CudaT*>(y_data),
|
||||
reinterpret_cast<CudaT*>(y_h_data),
|
||||
output_size);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,16 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Input slot numbers shared by the cuDNN-backed RNN kernels. Kept as an unscoped
// enum because the enumerators are passed where a plain input index is expected
// (e.g. InputMemoryType<...>(RNN_Input_Index::sequence_lens)).
enum RNN_Input_Index {
  X = 0,          // input sequence, [seq_length, batch_size, input_size]
  W,              // 1: presumably the input weights - confirm against the operator spec
  R,              // 2: presumably the recurrent weights - confirm against the operator spec
  B,              // 3: presumably the bias (optional) - confirm against the operator spec
  sequence_lens,  // 4: optional per-batch sequence lengths, [batch_size]
  initial_h,      // 5: optional initial hidden state, [num_directions, batch_size, hidden_size]
  initial_c       // 6: optional initial cell state (read only in LSTM mode)
};
|
||||
|
||||
class CudnnDropout {
|
||||
public:
|
||||
CudnnDropout() : dropout_desc_(nullptr) {
|
||||
|
|
@ -158,15 +168,6 @@ class CudnnRnnBase : public CudaKernel {
|
|||
IAllocatorUniquePtr<void> w_data_cache_;
|
||||
bool weight_cached_;
|
||||
|
||||
enum Input_Index {
|
||||
X = 0,
|
||||
W = 1,
|
||||
R = 2,
|
||||
B = 3,
|
||||
sequence_lens = 4,
|
||||
initial_h = 5,
|
||||
initial_c = 6
|
||||
};
|
||||
enum Output_Index {
|
||||
Y = 0,
|
||||
Y_h = 1,
|
||||
|
|
|
|||
|
|
@ -10,16 +10,17 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GRU, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GRU, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
|
||||
GRU<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float);
|
||||
|
|
|
|||
|
|
@ -8,16 +8,17 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
LSTM, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
LSTM, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
|
||||
LSTM<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float);
|
||||
|
|
|
|||
|
|
@ -10,16 +10,17 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
RNN, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
RNN, \
|
||||
kOnnxDomain, \
|
||||
7, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
|
||||
RNN<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float);
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
const auto& input_tensor = *ctx->Input<Tensor>(0);
|
||||
auto const& input_shape = input_tensor.Shape();
|
||||
auto dimension_count = input_shape.NumDimensions();
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<int64_t> input_dims(this, device_id, input_shape.GetDims());
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, device_id, dimension_count);
|
||||
CudaAsyncBuffer<int64_t> lower_pads(this, device_id, dimension_count);
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
if (output_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<int64_t> starts_buffer(this, device_id, dimension_count);
|
||||
gsl::span<int64_t> starts_buffer_span = starts_buffer.CpuSpan();
|
||||
for (int i = 0; i < dimension_count; ++i) {
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
|
|||
auto& input_dims = input_shape.GetDims();
|
||||
std::vector<int64_t> output_dimensions{input_dims};
|
||||
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<void*> output_ptr(this, device_id, num_outputs);
|
||||
gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ Status Tile<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
T* output_data = output_tensor.template MutableData<T>();
|
||||
const T* input_data = input_tensor.template Data<T>();
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_input_shape(this, device_id, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, device_id, rank);
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
|
||||
CudaAsyncBuffer<size_t> perm(this, device_id, *p_perm);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, device_id, rank);
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
// kernel
|
||||
int device_id = 0;
|
||||
int device_id = GetDeviceId();
|
||||
TensorPitches input_pitches(X_dims);
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
|
||||
gsl::span<int64_t> input_stride_span = input_strides.CpuSpan();
|
||||
|
|
|
|||
|
|
@ -474,19 +474,19 @@ void DeepCpuGruOpTestContext::RunTest(const std::vector<float>& X,
|
|||
alphas_,
|
||||
betas_);
|
||||
|
||||
//::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_,
|
||||
// expected_Y, expected_Y_h,
|
||||
// input_size_, batch_size, hidden_dim_, seq_length,
|
||||
// use_bias_ ? &gru_bias_ : nullptr,
|
||||
// initial_h,
|
||||
// &sequence_lens,
|
||||
// direction_,
|
||||
// 9999999999.f,
|
||||
// /*output_sequence*/ false,
|
||||
// linear_before_reset,
|
||||
// activation_func_names_,
|
||||
// alphas_,
|
||||
// betas_);
|
||||
::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_,
|
||||
expected_Y, expected_Y_h,
|
||||
input_size_, batch_size, hidden_dim_, seq_length,
|
||||
use_bias_ ? &gru_bias_ : nullptr,
|
||||
initial_h,
|
||||
&sequence_lens,
|
||||
direction_,
|
||||
9999999999.f,
|
||||
/*output_sequence*/ false,
|
||||
linear_before_reset,
|
||||
activation_func_names_,
|
||||
alphas_,
|
||||
betas_);
|
||||
}
|
||||
|
||||
TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) {
|
||||
|
|
@ -762,34 +762,33 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe
|
|||
ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
|
||||
}
|
||||
|
||||
// Need CPU fix
|
||||
//TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) {
|
||||
// const std::string direction = "bidirectional";
|
||||
// const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
|
||||
//
|
||||
// DeepCpuGruOpTestContext ctx(direction, activations);
|
||||
//
|
||||
// const int batch_size = 2;
|
||||
// const int seq_length = 2;
|
||||
// std::vector<float> X = {-0.455351f, -0.276391f,
|
||||
// 0.855351f, 0.676391f,
|
||||
// -0.185934f, -0.269585f,
|
||||
// 0.585934f, 0.669585f};
|
||||
// std::vector<int> sequence_length = {2, 1};
|
||||
// std::vector<float> initial_h = {0.0f, 0.0f, 0.0f, 0.0f,
|
||||
// 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
// std::vector<float> expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f,
|
||||
// -0.0559310019f, 0.101836264f, -0.385578573f, 0.0370728001f,
|
||||
//
|
||||
// -0.0577347837f, 0.0796165839f, 0.0f, 0.0f,
|
||||
// -0.0456649922f, 0.0462125242f, 0.0f, 0.0f};
|
||||
// std::vector<float> expected_Y_h = {-0.0577347837f, 0.0796165839f,
|
||||
// -0.275918573f, -0.00228558504f,
|
||||
// -0.0559310019f, 0.101836264f,
|
||||
// -0.385578573f, 0.0370728001f};
|
||||
//
|
||||
// ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
|
||||
//}
|
||||
// Bidirectional GRU run on a batch of two sequences with different lengths ({2, 1}),
// starting from an all-zero initial hidden state.
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) {
  const std::string direction = "bidirectional";
  const std::vector<std::string> activation_names = {"sigmoid", "tanh", "sigmoid", "tanh"};
  DeepCpuGruOpTestContext ctx(direction, activation_names);

  const int batch_size = 2;
  const int seq_length = 2;

  // Per-batch valid lengths: the second sequence is shorter than seq_length.
  std::vector<int> sequence_lens = {2, 1};

  std::vector<float> X = {-0.455351f, -0.276391f,
                          0.855351f, 0.676391f,
                          -0.185934f, -0.269585f,
                          0.585934f, 0.669585f};

  std::vector<float> initial_h = {0.0f, 0.0f, 0.0f, 0.0f,
                                  0.0f, 0.0f, 0.0f, 0.0f};

  // NOTE(review): presumably laid out as [seq, direction, batch, hidden]; the 0.0f
  // entries would then be the padded step of the length-1 sequence - confirm.
  std::vector<float> expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f,
                                   -0.0559310019f, 0.101836264f, -0.275918573f, -0.00228558504f,

                                   -0.0577347837f, 0.0796165839f, 0.0f, 0.0f,
                                   -0.0456649922f, 0.0462125242f, 0.0f, 0.0f};

  std::vector<float> expected_Y_h = {-0.0577347837f, 0.0796165839f,
                                     -0.275918573f, -0.00228558504f,
                                     -0.0559310019f, 0.101836264f,
                                     -0.275918573f, -0.00228558504f};

  // Final argument is presumably linear_before_reset (per the test name) - confirm
  // against RunTest's signature.
  ctx.RunTest(X, batch_size, seq_length, sequence_lens, &initial_h, expected_Y, expected_Y_h, true);
}
|
||||
|
||||
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) {
|
||||
const std::string direction = "bidirectional";
|
||||
|
|
|
|||
Loading…
Reference in a new issue