From fa0ea9a273a1a8a67cb8c9c73eca09954a848e01 Mon Sep 17 00:00:00 2001 From: Randy <45701928+RandyShuai@users.noreply.github.com> Date: Thu, 10 Jan 2019 09:42:18 -0800 Subject: [PATCH] implement dynamic slice cuda (#286) * implement dynamic slice cuda * add template parameter * add delaration * init base class * exclude case from cuda * use cuda mapped type * separate function implementation * add cpy logic * refactor * add type check * use InputMemoryType * merge functions --- .../core/providers/cpu/tensor/slice.cc | 48 ++++++++----------- onnxruntime/core/providers/cpu/tensor/slice.h | 14 +++--- .../providers/cuda/cuda_execution_provider.cc | 10 +++- .../core/providers/cuda/tensor/slice.cc | 41 +++++++++++----- .../core/providers/cuda/tensor/slice.h | 3 +- .../cpu/tensor/dynamic_slice_op_test.cc | 3 +- 6 files changed, 70 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 99f9aea0cf..e3c619ec5a 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -113,34 +113,28 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, return Status::OK(); } -template -void Slice::FillVectors(const OpKernelContext* context, - std::vector& input_starts, - std::vector& input_ends, - std::vector& input_axes) const { - ORT_ENFORCE(context->Input(1) != nullptr, "Required starts input is missing"); - ORT_ENFORCE(context->Input(2) != nullptr, "Required ends input is missing"); +template +void SliceBase::FillVectorsFromInput(const OpKernelContext* context, + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes) const { + auto stat_tensor = context->Input(1); + auto ends_tensor = context->Input(2); + auto axes_tensor = context->Input(3); - auto starts_tensor_ptr = context->Input(1); - ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array"); - input_starts = std::vector (starts_tensor_ptr->Data(), - starts_tensor_ptr->Data() + - starts_tensor_ptr->Shape().Size()); - - auto ends_tensor_ptr = context->Input(2); - ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array"); - input_ends = std::vector (ends_tensor_ptr->Data(), - ends_tensor_ptr->Data() + - ends_tensor_ptr->Shape().Size()); + ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" ); + ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" ); + ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch"); + ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); - ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input"); - - if (context->Input(3) != nullptr) { - auto axes_tensor_ptr = context->Input(3); - input_axes = std::vector (axes_tensor_ptr->Data(), - axes_tensor_ptr->Data() + - axes_tensor_ptr->Shape().Size()); - ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid"); + auto size = stat_tensor->Shape().Size(); + input_starts.resize(size); + std::copy(stat_tensor->Data(), stat_tensor->Data() + size, input_starts.begin()); + input_ends.resize(size); + std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); + if (nullptr != axes_tensor) { + input_axes.resize(size); + std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); } } @@ -158,7 +152,7 @@ Status Slice::Compute(OpKernelContext* ctx) const { if (dynamic) { std::vector input_starts, input_ends, input_axes; - FillVectors(ctx, input_starts, input_ends, input_axes); + FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, dimension_count, input_dimensions, starts, output_dims)); } else { diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index f91d86200a..82e8cd7208 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -15,9 +15,9 @@ class SliceBase { auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK(); auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK(); ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(), - "Missing or invalid starts and ends attribute"); + "Missing or invalid starts and ends attribute"); ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(), - "Invalid axes attribute"); + "Invalid axes attribute"); } } @@ -28,6 +28,11 @@ class SliceBase { const std::vector& input_dimensions, std::vector& starts, std::vector& output_dims) const; + template + void FillVectorsFromInput(const OpKernelContext* context, + std::vector& raw_starts, + std::vector& raw_ends, + std::vector& raw_axes) const; std::vector attr_starts_, attr_ends_, attr_axes_; }; @@ -36,11 +41,6 @@ template struct Slice final : public OpKernel, public SliceBase { Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {} Status Compute(OpKernelContext* context) const override; -private: - void FillVectors(const OpKernelContext* context, - std::vector& raw_starts, - std::vector& raw_ends, - std::vector& raw_axes) const; }; // namespace onnxruntime } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 96b7d7d3f1..c5c5f29a51 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -493,7 +493,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, LSTM); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Upsample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Upsample); @@ -753,7 +756,10 @@ static void RegisterCudaKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); - fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index b9868e4a8b..fd403e53a9 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -8,15 +8,27 @@ namespace onnxruntime { namespace cuda { -ONNX_OPERATOR_KERNEL_EX( - Slice, - kOnnxDomain, - 1, - kCudaExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), - Slice); +#define REGISTER_TYPED_SLICE(NAME, TIND, DYNAMIC) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NAME, \ + kOnnxDomain, \ + 1, \ + TIND, \ + kCudaExecutionProvider, \ + KernelDefBuilder().InputMemoryType(1). \ + InputMemoryType(2). \ + InputMemoryType(3). \ + TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()). \ + TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + Slice); -Status Slice::ComputeInternal(OpKernelContext* ctx) const { +REGISTER_TYPED_SLICE(Slice, int32_t, false) +REGISTER_TYPED_SLICE(Slice, int64_t, false) +REGISTER_TYPED_SLICE(DynamicSlice, int32_t, true ) +REGISTER_TYPED_SLICE(DynamicSlice, int64_t, true ) + +template +Status Slice::ComputeInternal(OpKernelContext* ctx) const { auto input_tensor = ctx->Input(0); ORT_ENFORCE(nullptr != input_tensor); auto& input_dimensions = input_tensor->Shape().GetDims(); @@ -26,9 +38,16 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { std::vector starts(dimension_count, 0); std::vector output_dims(input_dimensions); - ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, - dimension_count, input_dimensions, - starts, output_dims)); + if (dynamic) { + std::vector input_starts, input_ends, input_axes; + FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, + dimension_count, input_dimensions, starts, output_dims)); + + } else { + ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, + dimension_count, input_dimensions, starts, output_dims)); + } TensorShape output_shape(output_dims); auto output_tensor = ctx->Output(0, output_shape); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 442f3c368d..a94cb0657f 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -8,9 +8,10 @@ namespace onnxruntime { namespace cuda { +template class Slice final : public CudaKernel, public SliceBase { public: - Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info) {} + Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc b/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc index 7dfafff725..fd590efbdc 100644 --- a/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc @@ -22,12 +22,14 @@ TEST(DynamicSliceTest, dynamic_slice_varied_types) { test2.AddOutput ("output", {2,2}, {5LL,6LL,8LL,9LL}); test2.Run(); +#ifndef USE_CUDA OpTester test3("DynamicSlice", 1); test3.AddInput ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"}); test3.AddInput ("starts", {2}, {1,1}); test3.AddInput ("ends", {2}, {3,3}); test3.AddOutput ("output", {2,2}, {"e","f","h","i"}); test3.Run(); +#endif OpTester test4("DynamicSlice", 1); test4.AddInput ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f}); @@ -119,6 +121,5 @@ TEST(DynamicSliceTest, dynamic_slice_full_axes) { test2.AddOutput ("output", {1,2,1}, {5,8}); test2.Run(); } - } // namespace Test } // namespace onnxruntime