From 1c9d0b27293019eda0bbc189da2ad6fc5a0aef2e Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 3 Dec 2018 13:54:48 -0800 Subject: [PATCH] Add missing types for Slice op (#74) --- .../providers/cpu/cpu_execution_provider.cc | 28 +++++++++++++-- .../core/providers/cpu/tensor/slice.cc | 36 ++++++++++++++----- .../providers/cpu/tensor/slice_op.test.cc | 24 +++++++++++++ 3 files changed, 77 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 39ec9c6fc4..d377dca633 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -161,7 +161,19 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace); @@ -333,7 +345,19 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); - fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index fe7729c86d..f09e32fc3c 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -4,14 +4,32 @@ #include "core/providers/cpu/tensor/slice.h" #include "core/providers/cpu/tensor/utils.h" using namespace ::onnxruntime::common; +using namespace std; namespace onnxruntime { -ONNX_CPU_OPERATOR_KERNEL( - Slice, - 1, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Slice); +#define ADD_TYPED_SLICE_OP(data_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + Slice, \ + 1, \ + data_type, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Slice); + +ADD_TYPED_SLICE_OP(uint8_t); +ADD_TYPED_SLICE_OP(uint16_t); +ADD_TYPED_SLICE_OP(uint32_t); +ADD_TYPED_SLICE_OP(uint64_t); +ADD_TYPED_SLICE_OP(int8_t); +ADD_TYPED_SLICE_OP(int16_t); +ADD_TYPED_SLICE_OP(int32_t); +ADD_TYPED_SLICE_OP(int64_t); +ADD_TYPED_SLICE_OP(float); +ADD_TYPED_SLICE_OP(double); +ADD_TYPED_SLICE_OP(MLFloat16); +ADD_TYPED_SLICE_OP(bool); +ADD_TYPED_SLICE_OP(string); + namespace { // std::clamp doesn't exist until C++17 so create a local version template @@ -58,8 +76,8 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec return Status::OK(); } -template <> -Status Slice::Compute(OpKernelContext* ctx) const { +template +Status Slice::Compute(OpKernelContext* ctx) const { const Tensor* input_tensor_ptr = ctx->Input(0); ONNXRUNTIME_ENFORCE(input_tensor_ptr != nullptr); auto& input_tensor = *input_tensor_ptr; @@ -74,10 +92,10 @@ Status Slice::Compute(OpKernelContext* ctx) const { TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); - auto* output = output_tensor.template MutableData(); + auto* output = output_tensor.template MutableData(); const auto* output_end = output + output_shape.Size(); - SliceIterator input_iterator(input_tensor, starts, output_dims); + SliceIterator input_iterator(input_tensor, starts, output_dims); while (output != output_end) *output++ = *input_iterator++; diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 65f44d595b..00c91a3bba 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -134,5 +134,29 @@ TEST(SliceTest, Slice3D) { test.Run(); } +TEST(SliceTest, Slice1D_Int) { + OpTester test("Slice"); + + test.AddAttribute("axes", std::vector{0}); + test.AddAttribute("starts", std::vector{2}); + test.AddAttribute("ends", std::vector{4}); + + test.AddInput("data", {6}, {0L, 1L, 2L, 3L, 4L, 5L}); + test.AddOutput("output", {2}, {2L, 3L}); + test.Run(); +} + +TEST(SliceTest, Slice1D_String) { + OpTester test("Slice"); + + test.AddAttribute("axes", std::vector{0}); + test.AddAttribute("starts", std::vector{2}); + test.AddAttribute("ends", std::vector{4}); + + test.AddInput("data", {6}, {"0", "1", "2", "3", "4", "5"}); + test.AddOutput("output", {2}, {"2", "3"}); + test.Run(); +} + } // namespace Test } // namespace onnxruntime