diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 555c76c0d9..66b626ec06 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -171,6 +171,35 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split); @@ -355,6 +384,7 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); @@ -368,6 +398,35 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 23fae30964..99f9aea0cf 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -8,27 +8,64 @@ using namespace std; namespace onnxruntime { -#define ADD_TYPED_SLICE_OP(data_type) \ +#define ADD_TYPED_SLICE_OP(data_type, indice_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ Slice, \ 1, \ data_type, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Slice); + Slice); -ADD_TYPED_SLICE_OP(uint8_t); -ADD_TYPED_SLICE_OP(uint16_t); -ADD_TYPED_SLICE_OP(uint32_t); -ADD_TYPED_SLICE_OP(uint64_t); -ADD_TYPED_SLICE_OP(int8_t); -ADD_TYPED_SLICE_OP(int16_t); -ADD_TYPED_SLICE_OP(int32_t); -ADD_TYPED_SLICE_OP(int64_t); -ADD_TYPED_SLICE_OP(float); -ADD_TYPED_SLICE_OP(double); -ADD_TYPED_SLICE_OP(MLFloat16); -ADD_TYPED_SLICE_OP(bool); -ADD_TYPED_SLICE_OP(string); +ADD_TYPED_SLICE_OP(uint8_t, int64_t); +ADD_TYPED_SLICE_OP(uint16_t, int64_t); +ADD_TYPED_SLICE_OP(uint32_t, int64_t); +ADD_TYPED_SLICE_OP(uint64_t, int64_t); +ADD_TYPED_SLICE_OP(int8_t, int64_t); +ADD_TYPED_SLICE_OP(int16_t, int64_t); +ADD_TYPED_SLICE_OP(int32_t, int64_t); +ADD_TYPED_SLICE_OP(int64_t, int64_t); +ADD_TYPED_SLICE_OP(float, int64_t); +ADD_TYPED_SLICE_OP(double, int64_t); +ADD_TYPED_SLICE_OP(MLFloat16,int64_t); +ADD_TYPED_SLICE_OP(bool, int64_t); +ADD_TYPED_SLICE_OP(string, int64_t); + +#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DynamicSlice, \ + 1, \ + data_type##_##indice_type, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + Slice); + +ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t); + +ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t); namespace { // std::clamp doesn't exist until C++17 so create a local version @@ -39,33 +76,33 @@ const T& clamp(const T& v, const T& lo, const T& hi) { return v; } } // namespace -Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vector& input_dimensions, - std::vector& starts, std::vector& output_dims) const { + +Status SliceBase::PrepareForCompute(const std::vector& raw_starts, + const std::vector& raw_ends, + const std::vector& raw_axes, + const size_t dimension_count, + const std::vector& input_dimensions, + std::vector& starts, + std::vector& output_dims) const { // Initialize axes to the provided axes attribute or to the default sequence - std::vector axes(axes_); - if (!has_axes_) { + std::vector axes(raw_axes); + if (axes.size() == 0) { //axes are omitted, they are set to[0, ..., ndim - 1] axes.resize(starts.size()); - for (size_t i = 0; i < starts.size(); i++) - axes[i] = i; - - if (axes.size() > starts_.size()) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'starts' attribute holds"); - if (axes.size() > ends_.size()) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'ends' attribute holds"); + std::iota(axes.begin(), axes.end(), 0); } // Iterate through the provided axes and override the start/end ranges for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) { - auto axis = static_cast(axes[axesIndex]); - if (axis >= dimension_count) + auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast(dimension_count) : axes[axesIndex]; + if (axis >= static_cast(dimension_count) || axis < 0) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an axis outside of the tensor dimension count"); - auto start = starts_[axesIndex]; + auto start = raw_starts[axesIndex]; if (start < 0) start += input_dimensions[axis]; starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]); - auto end = ends_[axesIndex]; + auto end = raw_ends[axesIndex]; if (end < 0) end += input_dimensions[axis]; output_dims[axis] = clamp(end, int64_t{0}, input_dimensions[axis]) - starts[axis]; @@ -76,8 +113,39 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec return Status::OK(); } -template -Status Slice::Compute(OpKernelContext* ctx) const { +template +void Slice::FillVectors(const OpKernelContext* context, + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes) const { + ORT_ENFORCE(context->Input(1) != nullptr, "Required starts input is missing"); + ORT_ENFORCE(context->Input(2) != nullptr, "Required ends input is missing"); + + auto starts_tensor_ptr = context->Input(1); + ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array"); + input_starts = std::vector (starts_tensor_ptr->Data(), + starts_tensor_ptr->Data() + + starts_tensor_ptr->Shape().Size()); + + auto ends_tensor_ptr = context->Input(2); + ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array"); + input_ends = std::vector (ends_tensor_ptr->Data(), + ends_tensor_ptr->Data() + + ends_tensor_ptr->Shape().Size()); + + ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input"); + + if (context->Input(3) != nullptr) { + auto axes_tensor_ptr = context->Input(3); + input_axes = std::vector (axes_tensor_ptr->Data(), + axes_tensor_ptr->Data() + + axes_tensor_ptr->Shape().Size()); + ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid"); + } +} + +template +Status Slice::Compute(OpKernelContext* ctx) const { const Tensor* input_tensor_ptr = ctx->Input(0); ORT_ENFORCE(input_tensor_ptr != nullptr); auto& input_tensor = *input_tensor_ptr; @@ -88,7 +156,15 @@ Status Slice::Compute(OpKernelContext* ctx) const { std::vector starts(dimension_count, 0); std::vector output_dims(input_dimensions); - ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims)); + if (dynamic) { + std::vector input_starts, input_ends, input_axes; + FillVectors(ctx, input_starts, input_ends, input_axes); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, + dimension_count, input_dimensions, starts, output_dims)); + } else { + ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, + dimension_count, input_dimensions, starts, output_dims)); + } TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 2c71693e45..f91d86200a 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -9,33 +9,38 @@ namespace onnxruntime { class SliceBase { protected: - SliceBase(const OpKernelInfo& info) { - has_axes_ = info.GetAttrs("axes", axes_).IsOK(); - - ORT_ENFORCE(info.GetAttrs("starts", starts_).IsOK(), "Invalid 'starts' attribute value"); - ORT_ENFORCE(info.GetAttrs("ends", ends_).IsOK(), "Invalid 'ends' attribute value"); - - if (has_axes_) { - if (axes_.size() > starts_.size()) - ORT_THROW("'axes' has more entries than the 'starts' attribute holds"); - if (axes_.size() > ends_.size()) - ORT_THROW("'axes' has more entries than the 'ends' attribute holds"); + SliceBase (const OpKernelInfo& info, bool dynamic = false) { + if (!dynamic) { + auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK(); + auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK(); + auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK(); + ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(), + "Missing or invalid starts and ends attribute"); + ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(), + "Invalid axes attribute"); } } - Status PrepareForCompute(const size_t dimension_count, const std::vector& input_dimensions, - std::vector& starts, std::vector& output_dims) const; + Status PrepareForCompute(const std::vector& raw_starts, + const std::vector& raw_ends, + const std::vector& raw_axes, + const size_t dimension_count, + const std::vector& input_dimensions, + std::vector& starts, + std::vector& output_dims) const; - std::vector axes_; - bool has_axes_; - std::vector starts_, ends_; + std::vector attr_starts_, attr_ends_, attr_axes_; }; -template +template struct Slice final : public OpKernel, public SliceBase { - Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info) {} - + Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {} Status Compute(OpKernelContext* context) const override; +private: + void FillVectors(const OpKernelContext* context, + std::vector& raw_starts, + std::vector& raw_ends, + std::vector& raw_axes) const; }; // namespace onnxruntime } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index f1e4e0730a..b9868e4a8b 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -26,7 +26,9 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { std::vector starts(dimension_count, 0); std::vector output_dims(input_dimensions); - ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims)); + ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, + dimension_count, input_dimensions, + starts, output_dims)); TensorShape output_shape(output_dims); auto output_tensor = ctx->Output(0, output_shape); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6436fce3b9..fac6f4a823 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -270,11 +270,6 @@ int real_main(int argc, char* argv[]) { {"Softsign", "disable reason"}, {"convtranspose_1d", "disable reason"}, {"convtranspose_3d", "disable reason"}, - {"dynamic_slice", "disable reason"}, - {"dynamic_slice_default_axes", "disable reason"}, - {"dynamic_slice_end_out_of_bounds", "disable reason"}, - {"dynamic_slice_neg", "disable reason"}, - {"dynamic_slice_start_out_of_bounds", "disable reason"}, {"eyelike_populate_off_main_diagonal", "disable reason"}, {"eyelike_with_dtype", "disable reason"}, {"eyelike_without_dtype", "disable reason"}, diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc b/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc new file mode 100644 index 0000000000..7dfafff725 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(DynamicSliceTest, dynamic_slice_varied_types) { + OpTester test1 ("DynamicSlice", 1); + test1.AddInput ("data", {3,3}, {1,2,3,4,5,6,7,8,9}); + test1.AddInput ("starts", {2}, {1,1}); + test1.AddInput ("ends", {2}, {3,3}); + test1.AddOutput ("output", {2,2}, {5,6,8,9}); + test1.Run(); + + OpTester test2("DynamicSlice", 1); + test2.AddInput ("data", {3,3}, {1LL,2LL,3LL,4LL,5LL,6LL,7LL,8LL,9LL}); + test2.AddInput ("starts", {2}, {1,1}); + test2.AddInput ("ends", {2}, {3,3}); + test2.AddOutput ("output", {2,2}, {5LL,6LL,8LL,9LL}); + test2.Run(); + + OpTester test3("DynamicSlice", 1); + test3.AddInput ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"}); + test3.AddInput ("starts", {2}, {1,1}); + test3.AddInput ("ends", {2}, {3,3}); + test3.AddOutput ("output", {2,2}, {"e","f","h","i"}); + test3.Run(); + + OpTester test4("DynamicSlice", 1); + test4.AddInput ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f}); + test4.AddInput ("starts", {2}, {1,1}); + test4.AddInput ("ends", {2}, {3,3}); + test4.AddOutput ("output", {2,2}, {5.5f,6.6f,8.8f,9.9f}); + test4.Run(); + + OpTester test5("DynamicSlice", 1); + test5.AddInput ("data", {3,3}, {false,true,false,false,false,false,true,false,true}); + test5.AddInput ("starts", {2}, {1,1}); + test5.AddInput ("ends", {2}, {3,3}); + test5.AddOutput ("output", {2,2}, {false,false,false,true}); + test5.Run(); +} + +TEST(DynamicSliceTest, dynamic_slice_with_axes) { + OpTester test1 ("DynamicSlice", 1); + test1.AddInput ("data", {3,3}, {1,2,3,4,5,6,7,8,9}); + test1.AddInput ("starts", {1}, {1}); + test1.AddInput ("ends", {1}, {3}); + test1.AddInput ("axes", {1}, {-1}); + test1.AddOutput ("output", {3,2}, {2,3,5,6,8,9}); + test1.Run(); + + OpTester test2 ("DynamicSlice", 1); + test2.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test2.AddInput ("starts", {1}, {1}); + test2.AddInput ("ends", {1}, {2}); + test2.AddInput ("axes", {1}, {2}); + test2.AddOutput ("output", {3,3,1}, {2,5,8,11,14,17,20,23,26}); + test2.Run(); +} + +TEST(DynamicSliceTest, dynamic_slice_with_negative_axes) { + OpTester test1 ("DynamicSlice", 1); + test1.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test1.AddInput ("starts", {1}, {1}); + test1.AddInput ("ends", {1}, {-1}); + test1.AddInput ("axes", {1}, {1}); + test1.AddOutput ("output", {3,1,3}, {4,5,6,13,14,15,22,23,24}); + test1.Run(); + + OpTester test2 ("DynamicSlice", 1); + test2.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test2.AddInput ("starts", {2}, {-3,0}); + test2.AddInput ("ends", {2}, {-1,2}); + test2.AddInput ("axes", {2}, {0,2}); + test2.AddOutput ("output", {2,3,2}, {1,2,4,5,7,8,10,11,13,14,16,17}); + test2.Run(); +} + +TEST(DynamicSliceTest, dynamic_slice_ends_out_of_bounds) { + OpTester test ("DynamicSlice", 1); + test.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test.AddInput ("starts", {2}, {0,-2}); + test.AddInput ("ends", {2}, {2,1000}); + test.AddInput ("axes", {2}, {1,2}); + test.AddOutput ("output", {3,2,2}, {2,3,5,6,11,12,14,15,20,21,23,24}); + test.Run(); +} + +TEST(DynamicSliceTest, dynamic_slice_full_axes) { + OpTester test1 ("DynamicSlice", 1); + test1.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test1.AddInput ("starts", {3}, {0,1,1}); + test1.AddInput ("ends", {3}, {1,3,2}); + test1.AddInput ("axes", {3}, {0,1,2}); + test1.AddOutput ("output", {1,2,1}, {5,8}); + test1.Run(); + + OpTester test2 ("DynamicSlice", 1); + test2.AddInput ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10,11,12,13,14,15,16,17,18, + 19,20,21,22,23,24,25,26,27}); + test2.AddInput ("starts", {3}, {1,0,1}); + test2.AddInput ("ends", {3}, {2,1,3}); + test2.AddInput ("axes", {3}, {2,0,1}); + test2.AddOutput ("output", {1,2,1}, {5,8}); + test2.Run(); +} + +} // namespace Test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 7dc01e3083..70f5e99fad 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -25,11 +25,6 @@ backend_test.exclude(r'(test_acosh_cpu.*' '|test_convtranspose_1d_cpu.*' '|test_convtranspose_3d_cpu.*' '|test_cosh_example_cpu.*' -'|test_dynamic_slice_cpu.*' -'|test_dynamic_slice_default_axes_cpu.*' -'|test_dynamic_slice_end_out_of_bounds_cpu.*' -'|test_dynamic_slice_neg_cpu.*' -'|test_dynamic_slice_start_out_of_bounds_cpu.*' '|test_eyelike_populate_off_main_diagonal_cpu.*' '|test_eyelike_with_dtype_cpu.*' '|test_eyelike_without_dtype_cpu.*'