use cudnnRNNForwardInferenceEx for unpacked (padded) layout case (#899)

1. Use cudnnRNNForwardInferenceEx for unpacked (padded) layout case, place the sequence_lens data on CPU
2. Fix hard code device ID issue. In cuda kernel, it should get the device id from provider.
This commit is contained in:
Hector Li 2019-06-12 09:51:15 -07:00 committed by GitHub
parent 38963d81eb
commit bbdd1d658b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 212 additions and 118 deletions

View file

@ -140,6 +140,8 @@ class CudaKernel : public OpKernel {
return provider_->CopyTensor(src, dst);
}
inline int GetDeviceId() const { return provider_->GetDeviceId(); }
private:
CUDAExecutionProvider* provider_;
};

View file

@ -76,6 +76,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const override;
int GetDeviceId() const { return device_id_; }
private:
cudaStream_t streams_[kTotalCudaStreams];
int device_id_;

View file

@ -59,6 +59,41 @@ cudnnDataType_t CudnnTensor::GetDataType() {
ORT_THROW("cuDNN engine currently supports only single/double/half precision data types.");
}
CudnnDataTensor::CudnnDataTensor()
: tensor_(nullptr) {
}
CudnnDataTensor::~CudnnDataTensor() {
if (tensor_ != nullptr) {
cudnnDestroyRNNDataDescriptor(tensor_);
tensor_ = nullptr;
}
}
Status CudnnDataTensor::CreateTensorIfNeeded() {
if (!tensor_)
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDataDescriptor(&tensor_));
return Status::OK();
}
Status CudnnDataTensor::Set(cudnnDataType_t dataType,
int64_t max_seq_length,
int64_t batch_size,
int64_t data_size,
const int32_t* seq_lengths) {
ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());
cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED;
float padding_fill = 0.0f;
CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout,
static_cast<int>(max_seq_length),
static_cast<int>(batch_size),
static_cast<int>(data_size),
seq_lengths,
static_cast<void*>(&padding_fill)));
return Status::OK();
}
CudnnFilterDescriptor::CudnnFilterDescriptor() : desc_(nullptr) {
cudnnCreateFilterDescriptor(&desc_);
}

View file

@ -28,6 +28,25 @@ class CudnnTensor final {
cudnnTensorDescriptor_t tensor_;
};
class CudnnDataTensor final {
public:
CudnnDataTensor();
~CudnnDataTensor();
Status Set(cudnnDataType_t dataType,
int64_t max_seq_length,
int64_t batch_size,
int64_t data_size,
const int32_t* seq_lengths);
operator cudnnRNNDataDescriptor_t() const { return tensor_; }
private:
Status CreateTensorIfNeeded();
cudnnRNNDataDescriptor_t tensor_;
};
class CudnnFilterDescriptor final {
public:
CudnnFilterDescriptor();

View file

@ -67,7 +67,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
static_cast<int>(helper.N())));
return Status::OK();
}
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<const CudaT*> left_arrays(this, device_id, helper.LeftOffsets().size());
CudaAsyncBuffer<const CudaT*> right_arrays(this, device_id, helper.RightOffsets().size());
CudaAsyncBuffer<CudaT*> output_arrays(this, device_id, helper.OutputOffsets().size());

View file

@ -133,9 +133,9 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
const Tensor* W;
const Tensor* R;
const Tensor* B;
bool get_W = info.TryGetConstantInput(Input_Index::W, &W);
bool get_R = info.TryGetConstantInput(Input_Index::R, &R);
bool get_B = info.TryGetConstantInput(Input_Index::B, &B);
bool get_W = info.TryGetConstantInput(RNN_Input_Index::W, &W);
bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R);
bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);
if (get_W && get_R) {
if (get_B) {
@ -154,15 +154,15 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType<T>::MappedType CudaT;
// inputs
const Tensor* X = ctx->Input<Tensor>(Input_Index::X); // inputs. [seq_length, batch_size, input_size]
const Tensor* X = ctx->Input<Tensor>(RNN_Input_Index::X); // inputs. [seq_length, batch_size, input_size]
ORT_ENFORCE(nullptr != X);
// optional inputs
const Tensor* sequence_lens = ctx->Input<Tensor>(Input_Index::sequence_lens); // [batch_size]
const Tensor* initial_h = ctx->Input<Tensor>(Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens); // [batch_size]
const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
const Tensor* initial_c(nullptr);
if (rnn_mode_ == CUDNN_LSTM) {
initial_c = ctx->Input<Tensor>(Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
}
int64_t seq_length = X->Shape()[0];
@ -200,9 +200,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
IAllocatorUniquePtr<void> w_data;
CudnnFilterDescriptor w_desc;
if (!weight_cached_) {
const Tensor& W = *ctx->Input<Tensor>(Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(Input_Index::B);
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
ReorganizeWeights(&W, &R, B, w_data, w_desc);
}
@ -240,25 +240,54 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
workspace_bytes *= num_directions_;
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
rnn_desc_,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc.data(),
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
workspace_cuda.get(),
workspace_bytes));
if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
rnn_desc_,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc.data(),
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
workspace_cuda.get(),
workspace_bytes));
} else {
CudnnDataTensor x_desc;
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
CudnnDataTensor y_desc;
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
rnn_desc_,
x_desc,
x_data,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc,
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_cuda.get(),
workspace_bytes));
}
IAllocatorUniquePtr<T> y_reorganized_data;
if (reverse_ || num_directions_ == 2) {
//reverse output
@ -288,12 +317,16 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
}
}
if (sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
auto count = sequence_lens->Shape().Size();
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, GetDeviceId(), count);
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, count * sizeof(int32_t));
sequence_lens_buffer.CopyToGpu();
RnnMaskImpl(gsl::narrow_cast<int32_t>(num_directions_),
gsl::narrow_cast<int32_t>(seq_length),
gsl::narrow_cast<int32_t>(batch_size),
gsl::narrow_cast<int32_t>(hidden_size_),
sequence_lens_data,
sequence_lens_buffer.GpuPtr(),
reinterpret_cast<CudaT*>(y_data),
reinterpret_cast<CudaT*>(y_h_data),
output_size);

View file

@ -11,6 +11,16 @@
namespace onnxruntime {
namespace cuda {
enum RNN_Input_Index {
X = 0,
W = 1,
R = 2,
B = 3,
sequence_lens = 4,
initial_h = 5,
initial_c = 6
};
class CudnnDropout {
public:
CudnnDropout() : dropout_desc_(nullptr) {
@ -158,15 +168,6 @@ class CudnnRnnBase : public CudaKernel {
IAllocatorUniquePtr<void> w_data_cache_;
bool weight_cached_;
enum Input_Index {
X = 0,
W = 1,
R = 2,
B = 3,
sequence_lens = 4,
initial_h = 5,
initial_c = 6
};
enum Output_Index {
Y = 0,
Y_h = 1,

View file

@ -10,16 +10,17 @@
namespace onnxruntime {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GRU, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GRU, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
GRU<T>);
REGISTER_KERNEL_TYPED(float);

View file

@ -8,16 +8,17 @@
namespace onnxruntime {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
LSTM, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
LSTM, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
LSTM<T>);
REGISTER_KERNEL_TYPED(float);

View file

@ -10,16 +10,17 @@
namespace onnxruntime {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RNN, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RNN, \
kOnnxDomain, \
7, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
.InputMemoryType<OrtMemTypeCPUInput>(RNN_Input_Index::sequence_lens), \
RNN<T>);
REGISTER_KERNEL_TYPED(float);

View file

@ -24,7 +24,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
const auto& input_tensor = *ctx->Input<Tensor>(0);
auto const& input_shape = input_tensor.Shape();
auto dimension_count = input_shape.NumDimensions();
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<int64_t> input_dims(this, device_id, input_shape.GetDims());
CudaAsyncBuffer<int64_t> input_strides(this, device_id, dimension_count);
CudaAsyncBuffer<int64_t> lower_pads(this, device_id, dimension_count);

View file

@ -85,7 +85,7 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
if (output_size == 0) {
return Status::OK();
}
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<int64_t> starts_buffer(this, device_id, dimension_count);
gsl::span<int64_t> starts_buffer_span = starts_buffer.CpuSpan();
for (int i = 0; i < dimension_count; ++i) {

View file

@ -41,7 +41,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
auto& input_dims = input_shape.GetDims();
std::vector<int64_t> output_dimensions{input_dims};
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<void*> output_ptr(this, device_id, num_outputs);
gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
for (int i = 0; i < num_outputs; ++i) {

View file

@ -42,7 +42,7 @@ Status Tile<T>::ComputeInternal(OpKernelContext* ctx) const {
T* output_data = output_tensor.template MutableData<T>();
const T* input_data = input_tensor.template Data<T>();
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
CudaAsyncBuffer<fast_divmod> fdm_input_shape(this, device_id, rank);
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, device_id, rank);

View file

@ -92,7 +92,7 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}
int device_id = 0;
int device_id = GetDeviceId();
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
CudaAsyncBuffer<size_t> perm(this, device_id, *p_perm);
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, device_id, rank);

View file

@ -51,7 +51,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa
typedef typename ToCudaType<T>::MappedType CudaT;
// kernel
int device_id = 0;
int device_id = GetDeviceId();
TensorPitches input_pitches(X_dims);
CudaAsyncBuffer<int64_t> input_strides(this, device_id, rank);
gsl::span<int64_t> input_stride_span = input_strides.CpuSpan();

View file

@ -474,19 +474,19 @@ void DeepCpuGruOpTestContext::RunTest(const std::vector<float>& X,
alphas_,
betas_);
//::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_,
// expected_Y, expected_Y_h,
// input_size_, batch_size, hidden_dim_, seq_length,
// use_bias_ ? &gru_bias_ : nullptr,
// initial_h,
// &sequence_lens,
// direction_,
// 9999999999.f,
// /*output_sequence*/ false,
// linear_before_reset,
// activation_func_names_,
// alphas_,
// betas_);
::onnxruntime::test::RunGruTest(X, gru_input_weights_, gru_recurrent_weights_,
expected_Y, expected_Y_h,
input_size_, batch_size, hidden_dim_, seq_length,
use_bias_ ? &gru_bias_ : nullptr,
initial_h,
&sequence_lens,
direction_,
9999999999.f,
/*output_sequence*/ false,
linear_before_reset,
activation_func_names_,
alphas_,
betas_);
}
TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) {
@ -762,34 +762,33 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe
ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
}
// Need CPU fix
//TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) {
// const std::string direction = "bidirectional";
// const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
//
// DeepCpuGruOpTestContext ctx(direction, activations);
//
// const int batch_size = 2;
// const int seq_length = 2;
// std::vector<float> X = {-0.455351f, -0.276391f,
// 0.855351f, 0.676391f,
// -0.185934f, -0.269585f,
// 0.585934f, 0.669585f};
// std::vector<int> sequence_length = {2, 1};
// std::vector<float> initial_h = {0.0f, 0.0f, 0.0f, 0.0f,
// 0.0f, 0.0f, 0.0f, 0.0f};
// std::vector<float> expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f,
// -0.0559310019f, 0.101836264f, -0.385578573f, 0.0370728001f,
//
// -0.0577347837f, 0.0796165839f, 0.0f, 0.0f,
// -0.0456649922f, 0.0462125242f, 0.0f, 0.0f};
// std::vector<float> expected_Y_h = {-0.0577347837f, 0.0796165839f,
// -0.275918573f, -0.00228558504f,
// -0.0559310019f, 0.101836264f,
// -0.385578573f, 0.0370728001f};
//
// ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
//}
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) {
const std::string direction = "bidirectional";
const std::vector<std::string> activations = {"sigmoid", "tanh", "sigmoid", "tanh"};
DeepCpuGruOpTestContext ctx(direction, activations);
const int batch_size = 2;
const int seq_length = 2;
std::vector<float> X = {-0.455351f, -0.276391f,
0.855351f, 0.676391f,
-0.185934f, -0.269585f,
0.585934f, 0.669585f};
std::vector<int> sequence_length = {2, 1};
std::vector<float> initial_h = {0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f};
std::vector<float> expected_Y = {-0.0325528607f, 0.0774837881f, -0.275918573f, -0.00228558504f,
-0.0559310019f, 0.101836264f, -0.275918573f, -0.00228558504f,
-0.0577347837f, 0.0796165839f, 0.0f, 0.0f,
-0.0456649922f, 0.0462125242f, 0.0f, 0.0f};
std::vector<float> expected_Y_h = {-0.0577347837f, 0.0796165839f,
-0.275918573f, -0.00228558504f,
-0.0559310019f, 0.101836264f,
-0.275918573f, -0.00228558504f};
ctx.RunTest(X, batch_size, seq_length, sequence_length, &initial_h, expected_Y, expected_Y_h, true);
}
TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) {
const std::string direction = "bidirectional";