diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 9a946c8270..2e318271cd 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -177,33 +177,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice); - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace); @@ -442,33 +428,19 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index e3c619ec5a..5ceeb630dd 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -8,64 +8,51 @@ using namespace std; namespace onnxruntime { -#define ADD_TYPED_SLICE_OP(data_type, indice_type) \ +#define ADD_TYPED_SLICE_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ Slice, \ 1, \ data_type, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Slice); + Slice); -ADD_TYPED_SLICE_OP(uint8_t, int64_t); -ADD_TYPED_SLICE_OP(uint16_t, int64_t); -ADD_TYPED_SLICE_OP(uint32_t, int64_t); -ADD_TYPED_SLICE_OP(uint64_t, int64_t); -ADD_TYPED_SLICE_OP(int8_t, int64_t); -ADD_TYPED_SLICE_OP(int16_t, int64_t); -ADD_TYPED_SLICE_OP(int32_t, int64_t); -ADD_TYPED_SLICE_OP(int64_t, int64_t); -ADD_TYPED_SLICE_OP(float, int64_t); -ADD_TYPED_SLICE_OP(double, int64_t); -ADD_TYPED_SLICE_OP(MLFloat16,int64_t); -ADD_TYPED_SLICE_OP(bool, int64_t); -ADD_TYPED_SLICE_OP(string, int64_t); +ADD_TYPED_SLICE_OP(uint8_t); +ADD_TYPED_SLICE_OP(uint16_t); +ADD_TYPED_SLICE_OP(uint32_t); +ADD_TYPED_SLICE_OP(uint64_t); +ADD_TYPED_SLICE_OP(int8_t); +ADD_TYPED_SLICE_OP(int16_t); +ADD_TYPED_SLICE_OP(int32_t); +ADD_TYPED_SLICE_OP(int64_t); +ADD_TYPED_SLICE_OP(float); +ADD_TYPED_SLICE_OP(double); +ADD_TYPED_SLICE_OP(MLFloat16); +ADD_TYPED_SLICE_OP(bool); +ADD_TYPED_SLICE_OP(string); -#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - DynamicSlice, \ - 1, \ - data_type##_##indice_type, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ - Slice); +#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DynamicSlice, \ + 1, \ + data_type, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + Slice); -ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t); - -ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int8_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int16_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int32_t); +ADD_TYPED_DYNAMIC_SLICE_OP(int64_t); +ADD_TYPED_DYNAMIC_SLICE_OP(float); +ADD_TYPED_DYNAMIC_SLICE_OP(double); +ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16); +ADD_TYPED_DYNAMIC_SLICE_OP(bool); +ADD_TYPED_DYNAMIC_SLICE_OP(string); namespace { // std::clamp doesn't exist until C++17 so create a local version @@ -78,12 +65,11 @@ const T& clamp(const T& v, const T& lo, const T& hi) { } // namespace Status SliceBase::PrepareForCompute(const std::vector& raw_starts, - const std::vector& raw_ends, + const std::vector& raw_ends, const std::vector& raw_axes, - const size_t dimension_count, const std::vector& input_dimensions, - std::vector& starts, - std::vector& output_dims) const { + std::vector& starts, + std::vector& output_dims) const { // Initialize axes to the provided axes attribute or to the default sequence std::vector axes(raw_axes); if (axes.size() == 0) { @@ -93,6 +79,7 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, } // Iterate through the provided axes and override the start/end ranges + const auto& dimension_count = input_dimensions.size(); for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) { auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast(dimension_count) : axes[axesIndex]; if (axis >= static_cast(dimension_count) || axis < 0) @@ -113,63 +100,84 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, return Status::OK(); } -template void SliceBase::FillVectorsFromInput(const OpKernelContext* context, - std::vector& input_starts, - std::vector& input_ends, - std::vector& input_axes) const { - auto stat_tensor = context->Input(1); + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes) const { + auto start_tensor = context->Input(1); auto ends_tensor = context->Input(2); auto axes_tensor = context->Input(3); - ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" ); - ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" ); - ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch"); - ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); + ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array"); + ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array"); + ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch"); + ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); - auto size = stat_tensor->Shape().Size(); + const auto& dtype = start_tensor->DataType(); + const auto& size = start_tensor->Shape().Size(); input_starts.resize(size); - std::copy(stat_tensor->Data(), stat_tensor->Data() + size, input_starts.begin()); input_ends.resize(size); - std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); - if (nullptr != axes_tensor) { + if (nullptr != axes_tensor) input_axes.resize(size); - std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); + + if (dtype == DataTypeImpl::GetType()) { + std::copy(start_tensor->Data(), start_tensor->Data() + size, input_starts.begin()); + std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); + if (nullptr != axes_tensor) + std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); + } + + else if (dtype == DataTypeImpl::GetType()) { + std::copy(start_tensor->Data(), start_tensor->Data() + size, input_starts.begin()); + std::copy(ends_tensor->Data(), ends_tensor->Data() + size, input_ends.begin()); + if (nullptr != axes_tensor) + std::copy(axes_tensor->Data(), axes_tensor->Data() + size, input_axes.begin()); + } + + // should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check + else { + ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", dtype); } } -template -Status Slice::Compute(OpKernelContext* ctx) const { - const Tensor* input_tensor_ptr = ctx->Input(0); - ORT_ENFORCE(input_tensor_ptr != nullptr); - auto& input_tensor = *input_tensor_ptr; - auto& input_dimensions = input_tensor.Shape().GetDims(); - - // Initialize the starts & ends to the actual tensor shape - const size_t dimension_count = input_dimensions.size(); - std::vector starts(dimension_count, 0); - std::vector output_dims(input_dimensions); - - if (dynamic) { - std::vector input_starts, input_ends, input_axes; - FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, - dimension_count, input_dimensions, starts, output_dims)); - } else { - ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, - dimension_count, input_dimensions, starts, output_dims)); - } - +template +Status SliceImpl(OpKernelContext* ctx, + const Tensor& input_tensor, + std::vector& output_dims, + const std::vector& starts) { TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); auto* output = output_tensor.template MutableData(); - const auto* output_end = output + output_shape.Size(); + const auto* output_end = output + output_tensor.Shape().Size(); - SliceIterator input_iterator(input_tensor, starts, output_dims); + SliceIterator input_iterator(input_tensor, starts, output_tensor.Shape().GetDims()); while (output != output_end) *output++ = *input_iterator++; return Status::OK(); } +template +Status Slice::Compute(OpKernelContext* ctx) const { + const Tensor* input_tensor_ptr = ctx->Input(0); + ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed"); + const auto& input_tensor = *input_tensor_ptr; + const auto& input_dimensions = input_tensor.Shape().GetDims(); + + // Initialize the starts & ends to the actual tensor shape + std::vector starts(input_dimensions.size(), 0); + std::vector output_dims(input_dimensions); + + if (dynamic) { + std::vector input_starts, input_ends, input_axes; + FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); + ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, + input_dimensions, starts, output_dims)); + } else { + ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, + input_dimensions, starts, output_dims)); + } + + return SliceImpl(ctx, input_tensor, output_dims, starts); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 82e8cd7208..3e545bf1e5 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -9,35 +9,34 @@ namespace onnxruntime { class SliceBase { protected: - SliceBase (const OpKernelInfo& info, bool dynamic = false) { + SliceBase(const OpKernelInfo& info, bool dynamic = false) { if (!dynamic) { auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK(); - auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK(); - auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK(); + auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK(); + auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK(); ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(), - "Missing or invalid starts and ends attribute"); + "Missing or invalid starts and ends attribute"); ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(), - "Invalid axes attribute"); + "Invalid axes attribute, axes attribute (if present) should have the same size as starts/ends attributes"); } } Status PrepareForCompute(const std::vector& raw_starts, - const std::vector& raw_ends, + const std::vector& raw_ends, const std::vector& raw_axes, - const size_t dimension_count, const std::vector& input_dimensions, - std::vector& starts, - std::vector& output_dims) const; - template + std::vector& starts, + std::vector& output_dims) const; + void FillVectorsFromInput(const OpKernelContext* context, - std::vector& raw_starts, - std::vector& raw_ends, - std::vector& raw_axes) const; + std::vector& raw_starts, + std::vector& raw_ends, + std::vector& raw_axes) const; std::vector attr_starts_, attr_ends_, attr_axes_; }; -template +template struct Slice final : public OpKernel, public SliceBase { Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {} Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index fd403e53a9..e3b8197ad6 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -40,13 +40,13 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { if (dynamic) { std::vector input_starts, input_ends, input_axes; - FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); + FillVectorsFromInput(ctx, input_starts, input_ends, input_axes); ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, - dimension_count, input_dimensions, starts, output_dims)); + input_dimensions, starts, output_dims)); } else { ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, - dimension_count, input_dimensions, starts, output_dims)); + input_dimensions, starts, output_dims)); } TensorShape output_shape(output_dims);