implement dynamic slice cuda (#286)

* implement dynamic slice cuda

* add template parameter

* add declaration

* init base class

* exclude case from cuda

* use cuda mapped type

* separate function implementation

* add cpy logic

* refactor

* add type check

* use InputMemoryType

* merge functions
This commit is contained in:
Randy 2019-01-10 09:42:18 -08:00 committed by GitHub
parent 98a92547bf
commit fa0ea9a273
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 70 additions and 49 deletions

View file

@ -113,34 +113,28 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
return Status::OK();
}
template <typename T, typename Tind, bool dynamic>
void Slice<T, Tind, dynamic>::FillVectors(const OpKernelContext* context,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes) const {
ORT_ENFORCE(context->Input<Tensor>(1) != nullptr, "Required starts input is missing");
ORT_ENFORCE(context->Input<Tensor>(2) != nullptr, "Required ends input is missing");
template <typename Tind>
void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes) const {
auto stat_tensor = context->Input<Tensor>(1);
auto ends_tensor = context->Input<Tensor>(2);
auto axes_tensor = context->Input<Tensor>(3);
auto starts_tensor_ptr = context->Input<Tensor>(1);
ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array");
input_starts = std::vector<int64_t> (starts_tensor_ptr->Data<Tind>(),
starts_tensor_ptr->Data<Tind>() +
starts_tensor_ptr->Shape().Size());
auto ends_tensor_ptr = context->Input<Tensor>(2);
ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array");
input_ends = std::vector<int64_t> (ends_tensor_ptr->Data<Tind>(),
ends_tensor_ptr->Data<Tind>() +
ends_tensor_ptr->Shape().Size());
ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" );
ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" );
ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input");
if (context->Input<Tensor>(3) != nullptr) {
auto axes_tensor_ptr = context->Input<Tensor>(3);
input_axes = std::vector<int64_t> (axes_tensor_ptr->Data<Tind>(),
axes_tensor_ptr->Data<Tind>() +
axes_tensor_ptr->Shape().Size());
ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid");
auto size = stat_tensor->Shape().Size();
input_starts.resize(size);
std::copy(stat_tensor->Data<Tind>(), stat_tensor->Data<Tind>() + size, input_starts.begin());
input_ends.resize(size);
std::copy(ends_tensor->Data<Tind>(), ends_tensor->Data<Tind>() + size, input_ends.begin());
if (nullptr != axes_tensor) {
input_axes.resize(size);
std::copy(axes_tensor->Data<Tind>(), axes_tensor->Data<Tind>() + size, input_axes.begin());
}
}
@ -158,7 +152,7 @@ Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes;
FillVectors(ctx, input_starts, input_ends, input_axes);
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
dimension_count, input_dimensions, starts, output_dims));
} else {

View file

@ -15,9 +15,9 @@ class SliceBase {
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
"Missing or invalid starts and ends attribute");
"Missing or invalid starts and ends attribute");
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
"Invalid axes attribute");
"Invalid axes attribute");
}
}
@ -28,6 +28,11 @@ class SliceBase {
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const;
template<typename Tind>
void FillVectorsFromInput(const OpKernelContext* context,
std::vector<int64_t>& raw_starts,
std::vector<int64_t>& raw_ends,
std::vector<int64_t>& raw_axes) const;
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
};
@ -36,11 +41,6 @@ template <typename T, typename Tind, bool dynamic>
struct Slice final : public OpKernel, public SliceBase {
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
Status Compute(OpKernelContext* context) const override;
private:
void FillVectors(const OpKernelContext* context,
std::vector<int64_t>& raw_starts,
std::vector<int64_t>& raw_ends,
std::vector<int64_t>& raw_axes) const;
}; // namespace onnxruntime
} // namespace onnxruntime

View file

@ -493,7 +493,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, LSTM);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Upsample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Upsample);
@ -753,7 +756,10 @@ static void RegisterCudaKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, LSTM)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, LSTM)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, LSTM)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Upsample)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Upsample)>());

View file

@ -8,15 +8,27 @@
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
Slice,
kOnnxDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Slice);
#define REGISTER_TYPED_SLICE(NAME, TIND, DYNAMIC) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
NAME, \
kOnnxDomain, \
1, \
TIND, \
kCudaExecutionProvider, \
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1). \
InputMemoryType<OrtMemTypeCPUInput>(2). \
InputMemoryType<OrtMemTypeCPUInput>(3). \
TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()). \
TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIND>()), \
Slice<TIND,DYNAMIC>);
Status Slice::ComputeInternal(OpKernelContext* ctx) const {
REGISTER_TYPED_SLICE(Slice, int32_t, false)
REGISTER_TYPED_SLICE(Slice, int64_t, false)
REGISTER_TYPED_SLICE(DynamicSlice, int32_t, true )
REGISTER_TYPED_SLICE(DynamicSlice, int64_t, true )
template<typename Tind, bool dynamic>
Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
auto input_tensor = ctx->Input<Tensor>(0);
ORT_ENFORCE(nullptr != input_tensor);
auto& input_dimensions = input_tensor->Shape().GetDims();
@ -26,9 +38,16 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const {
std::vector<int64_t> starts(dimension_count, 0);
std::vector<int64_t> output_dims(input_dimensions);
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions,
starts, output_dims));
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes;
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
dimension_count, input_dimensions, starts, output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions, starts, output_dims));
}
TensorShape output_shape(output_dims);
auto output_tensor = ctx->Output(0, output_shape);

View file

@ -8,9 +8,10 @@
namespace onnxruntime {
namespace cuda {
template<typename Tind, bool dynamic>
class Slice final : public CudaKernel, public SliceBase {
public:
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info) {}
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

View file

@ -22,12 +22,14 @@ TEST(DynamicSliceTest, dynamic_slice_varied_types) {
test2.AddOutput <int64_t> ("output", {2,2}, {5LL,6LL,8LL,9LL});
test2.Run();
#ifndef USE_CUDA
OpTester test3("DynamicSlice", 1);
test3.AddInput <std::string> ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"});
test3.AddInput <int64_t> ("starts", {2}, {1,1});
test3.AddInput <int64_t> ("ends", {2}, {3,3});
test3.AddOutput <std::string> ("output", {2,2}, {"e","f","h","i"});
test3.Run();
#endif
OpTester test4("DynamicSlice", 1);
test4.AddInput <float> ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f});
@ -119,6 +121,5 @@ TEST(DynamicSliceTest, dynamic_slice_full_axes) {
test2.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
test2.Run();
}
} // namespace Test
} // namespace onnxruntime