From 48e96ea65f286eb5bf4e05bcc9b5011a18bec274 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 8 Apr 2020 07:19:29 +1000 Subject: [PATCH] Reduce binary size of Slice implementation (#3238) * Make the Slice implementation based on type sizes and reduce templatized code to a minimum. * Remove using 'dynamic' as a template param to Slice as well. --- onnxruntime/contrib_ops/cpu/dynamicslice.cc | 32 +--- .../contrib_ops/cpu_contrib_kernels.cc | 28 +-- .../providers/cpu/cpu_execution_provider.cc | 125 +------------- .../core/providers/cpu/tensor/slice.cc | 134 ++++++--------- onnxruntime/core/providers/cpu/tensor/slice.h | 26 ++- onnxruntime/core/providers/cpu/tensor/utils.h | 162 +++++++++++++----- .../core/providers/cuda/tensor/slice.cc | 2 +- 7 files changed, 220 insertions(+), 289 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/dynamicslice.cc b/onnxruntime/contrib_ops/cpu/dynamicslice.cc index f039b4de2e..58881ea3d3 100644 --- a/onnxruntime/contrib_ops/cpu/dynamicslice.cc +++ b/onnxruntime/contrib_ops/cpu/dynamicslice.cc @@ -3,33 +3,19 @@ #include "core/providers/cpu/tensor/slice.h" -using namespace ::onnxruntime::common; +using namespace onnxruntime::common; using namespace std; namespace onnxruntime { namespace contrib { -#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - DynamicSlice, \ - 1, \ - data_type, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), \ - Slice); +ONNX_CPU_OPERATOR_KERNEL( + DynamicSlice, + 1, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + Slice10); -ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int8_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int16_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int32_t); -ADD_TYPED_DYNAMIC_SLICE_OP(int64_t); -ADD_TYPED_DYNAMIC_SLICE_OP(float); -ADD_TYPED_DYNAMIC_SLICE_OP(double); -ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16); -ADD_TYPED_DYNAMIC_SLICE_OP(bool); -ADD_TYPED_DYNAMIC_SLICE_OP(string); - -} // namespace contrib_ops +} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc index 4e0c7f6872..315e4857c5 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -42,19 +42,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastG // we cannot change the domain now as this will break backward compatibility. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, DynamicSlice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus); @@ -126,19 +114,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { // contrib ops to main backward compatibility BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 47c87eb2ec..c0167f23b0 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -189,19 +189,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, bool, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MLFloat16, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint8_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint16_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint64_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int8_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int16_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, string, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, DepthToSpace); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, 10, Split); @@ -300,19 +288,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, bool, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, float, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, double, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint16_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint64_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int16_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, string, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); @@ -371,19 +347,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Fl class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gather); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint16_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int8_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int16_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Slice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, string, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze); @@ -773,32 +737,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -969,32 +909,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { MatMulInteger)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1044,32 +960,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 0e43718b81..1a6ab71837 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -11,74 +11,29 @@ using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { -#define ADD_TYPED_SLICE_V9_OP(data_type) \ - ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ - Slice, \ - 1, 9, \ - data_type, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Slice); +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Slice, + 1, 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + Slice1); -ADD_TYPED_SLICE_V9_OP(uint8_t); -ADD_TYPED_SLICE_V9_OP(uint16_t); -ADD_TYPED_SLICE_V9_OP(uint32_t); -ADD_TYPED_SLICE_V9_OP(uint64_t); -ADD_TYPED_SLICE_V9_OP(int8_t); -ADD_TYPED_SLICE_V9_OP(int16_t); -ADD_TYPED_SLICE_V9_OP(int32_t); -ADD_TYPED_SLICE_V9_OP(int64_t); -ADD_TYPED_SLICE_V9_OP(float); -ADD_TYPED_SLICE_V9_OP(double); -ADD_TYPED_SLICE_V9_OP(MLFloat16); -ADD_TYPED_SLICE_V9_OP(bool); -ADD_TYPED_SLICE_V9_OP(string); +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Slice, + 10, 10, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Slice10); -#define ADD_TYPED_SLICE_V10_OP(data_type) \ - ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ - Slice, \ - 10, \ - 10, \ - data_type, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), \ - Slice); - -ADD_TYPED_SLICE_V10_OP(uint8_t); -ADD_TYPED_SLICE_V10_OP(uint16_t); -ADD_TYPED_SLICE_V10_OP(uint32_t); -ADD_TYPED_SLICE_V10_OP(uint64_t); -ADD_TYPED_SLICE_V10_OP(int8_t); -ADD_TYPED_SLICE_V10_OP(int16_t); -ADD_TYPED_SLICE_V10_OP(int32_t); -ADD_TYPED_SLICE_V10_OP(int64_t); -ADD_TYPED_SLICE_V10_OP(float); -ADD_TYPED_SLICE_V10_OP(double); -ADD_TYPED_SLICE_V10_OP(MLFloat16); -ADD_TYPED_SLICE_V10_OP(bool); -ADD_TYPED_SLICE_V10_OP(string); - -#define ADD_TYPED_SLICE_V11_OP(data_type) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - Slice, \ - 11, \ - data_type, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), \ - Slice); - -ADD_TYPED_SLICE_V11_OP(uint8_t); -ADD_TYPED_SLICE_V11_OP(uint16_t); -ADD_TYPED_SLICE_V11_OP(uint32_t); -ADD_TYPED_SLICE_V11_OP(uint64_t); -ADD_TYPED_SLICE_V11_OP(int8_t); -ADD_TYPED_SLICE_V11_OP(int16_t); -ADD_TYPED_SLICE_V11_OP(int32_t); -ADD_TYPED_SLICE_V11_OP(int64_t); -ADD_TYPED_SLICE_V11_OP(float); -ADD_TYPED_SLICE_V11_OP(double); -ADD_TYPED_SLICE_V11_OP(MLFloat16); -ADD_TYPED_SLICE_V11_OP(bool); -ADD_TYPED_SLICE_V11_OP(string); +ONNX_CPU_OPERATOR_KERNEL( + Slice, + 11, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Slice10); namespace { // std::clamp doesn't exist until C++17 so create a local version @@ -315,12 +270,12 @@ void SliceBase::FillVectorsFromInput(const OpKernelContext* context, } template -Status SliceImpl(OpKernelContext* ctx, - const Tensor& input_tensor, - std::vector& output_dims, - std::vector* flattened_output_dims, - const std::vector& starts, - const std::vector& steps) { +static Status SliceImpl(OpKernelContext* ctx, + const Tensor& input_tensor, + std::vector& output_dims, + std::vector* flattened_output_dims, + const std::vector& starts, + const std::vector& steps) { TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -328,7 +283,8 @@ Status SliceImpl(OpKernelContext* ctx, if (output_shape.Size() == 0) return Status::OK(); - auto* output = output_tensor.template MutableData(); + // use MutableDataRaw as actual data type in tensor may not match as we templatize on data size + T* output = reinterpret_cast(output_tensor.MutableDataRaw()); const auto* output_end = output + output_tensor.Shape().Size(); auto create_output = [&output, &output_end](SliceIterator& input_iterator) { @@ -363,8 +319,7 @@ Status SliceImpl(OpKernelContext* ctx, return Status::OK(); } -template -Status Slice::Compute(OpKernelContext* ctx) const { +Status SliceBase::Compute(OpKernelContext* ctx) const { const auto* input_tensor_ptr = ctx->Input(0); ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed"); const auto& input_tensor = *input_tensor_ptr; @@ -379,7 +334,7 @@ Status Slice::Compute(OpKernelContext* ctx) const { std::vector* p_flattened_output_dims = &flattened_output_dims; // Slice V10 & DynamicSlice - if (dynamic) { + if (dynamic_) { std::vector input_starts; std::vector input_ends; std::vector input_axes; @@ -396,6 +351,31 @@ Status Slice::Compute(OpKernelContext* ctx) const { p_flattened_output_dims)); } - return SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + Status status = Status::OK(); + + if (input_tensor.IsDataTypeString()) { + status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + } else { + const auto element_size = input_tensor.DataType()->Size(); + + switch (element_size) { + case sizeof(uint32_t): + status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + break; + case sizeof(uint64_t): + status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + break; + case sizeof(uint16_t): + status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + break; + case sizeof(uint8_t): + status = SliceImpl(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps); + break; + default: + ORT_THROW("Unsupported input data type of ", input_tensor.DataType()); + } + } + return status; } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 4435c68b9d..f81152403d 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -9,7 +9,8 @@ namespace onnxruntime { class SliceBase { protected: - SliceBase(const OpKernelInfo& info, bool dynamic = false) { + SliceBase(const OpKernelInfo& info, bool dynamic = false) + : dynamic_(dynamic) { if (!dynamic) { auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK(); auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK(); @@ -49,13 +50,26 @@ class SliceBase { std::vector& input_axes, std::vector& input_steps) const; + Status Compute(OpKernelContext* context) const; + + protected: + const std::vector& StartsAttribute() const { return attr_starts_; } + const std::vector& EndsAttribute() const { return attr_ends_; } + const std::vector& AxesAttribute() const { return attr_axes_; } + + private: + bool dynamic_; std::vector attr_starts_, attr_ends_, attr_axes_; }; -template -struct Slice final : public OpKernel, public SliceBase { - Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {} - Status Compute(OpKernelContext* context) const override; -}; // namespace onnxruntime +struct Slice1 final : public OpKernel, public SliceBase { + Slice1(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, false) {} + Status Compute(OpKernelContext* context) const override { return SliceBase::Compute(context); } +}; + +struct Slice10 final : public OpKernel, public SliceBase { + Slice10(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, true) {} + Status Compute(OpKernelContext* context) const override { return SliceBase::Compute(context); } +}; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index a1a9773a85..c905fa32b8 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -157,10 +157,14 @@ struct SliceSkips : std::vector { }; // This provides easy sequential iteration over a subset of a tensor given a span of starts, extents & optionally steps -template -struct SliceIterator { - SliceIterator(const Tensor& tensor, gsl::span starts, - gsl::span extents, gsl::span steps) +// The base class is type agnostic to minimize binary size. The derived class provides any type specific logic. +struct SliceIteratorBase { + private: + enum class byte : unsigned char {}; + + protected: + SliceIteratorBase(const Tensor& tensor, gsl::span starts, + gsl::span extents, gsl::span steps) : tensor_(tensor), extents_(extents), skips_(tensor_.Shape(), extents, steps), indices_(extents.size(), 0) { auto& dims = tensor_.Shape().GetDims(); Init(dims, starts, steps); @@ -170,16 +174,15 @@ struct SliceIterator { // The explicit tensor_shape usually has inner most axis flattened. For example, given shape[1,4,4,2], if last axis // does not have padding or slice, then it will be flattened as [1,4,8] for better performance (One inner most copy instead of 4). // Also supports arbitrary positive and negative stepping along individual axes - SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span starts, - gsl::span extents, gsl::span steps) + SliceIteratorBase(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span starts, + gsl::span extents, gsl::span steps) : tensor_(tensor), extents_(extents), skips_(tensor_shape, extents, steps), indices_(extents.size(), 0) { const auto& dims = tensor_shape.GetDims(); Init(dims, starts, steps); } // Initialize initial skip and inner_extent. - void Init(const std::vector& dims, gsl::span starts, - gsl::span steps) { + void Init(const std::vector& dims, gsl::span starts, gsl::span steps) { ORT_ENFORCE(dims.size() == starts.size() && dims.size() == extents_.size() && dims.size() >= steps.size()); @@ -187,7 +190,7 @@ struct SliceIterator { size_t pitch = 1; // Initial skip, so that input_ points to the first element to copy for (size_t i = dims.size(); i-- > 0;) { - input_ += pitch * starts[i]; + input_ += pitch * starts[i] * element_size_; pitch *= dims[i]; } @@ -199,38 +202,74 @@ struct SliceIterator { void AdvanceOverInnerExtent() { size_t axis = skips_.size() - 1; - input_ += skips_[axis]; + input_ += skips_[axis] * element_size_; while (axis-- && ++indices_[axis] == extents_[axis]) { indices_[axis] = 0; - input_ += skips_[axis]; + input_ += skips_[axis] * element_size_; } } void IncrementInnerDimension() { - input_ += inner_step_; + input_ += inner_step_ * element_size_; if (++inner_counter_ == inner_extent_) { inner_counter_ = 0; AdvanceOverInnerExtent(); } } - // postfix iterator increment - const T* operator++(int) { - const T* input = input_; - IncrementInnerDimension(); - return input; - } - - // prefix iterator increment - const T* operator++() { - IncrementInnerDimension(); + const void* cur_input() const { return input_; } - const T& operator*() const { - return *input_; + // Assumes SolitaryInnerStep() == true + void* CopyInnermostAxisSolitaryInnerStep(void* output) { + byte* out_bytes = reinterpret_cast(output); + auto bytes_to_copy = inner_extent_ * element_size_; + + if (!is_string_tensor_) { + std::copy(input_, input_ + bytes_to_copy, out_bytes); + } else { + const std::string* input = reinterpret_cast(input_); + std::string* out = reinterpret_cast(output); + std::copy(input, input + inner_extent_, out); + } + + input_ += bytes_to_copy; + out_bytes += bytes_to_copy; + AdvanceOverInnerExtent(); + + return out_bytes; } + // Assumes generic inner_step_ + void* CopyInnermostAxisNonSolitaryInnerStep(void* output) { + // need to special case std::string so the copy works correctly + if (!is_string_tensor_) { + // switch on element size so copy is efficient + switch (element_size_) { + case sizeof(uint8_t): + output = TypedCopyInnermostAxisNonSolitaryInnerStep(output); + break; + case sizeof(uint16_t): + output = TypedCopyInnermostAxisNonSolitaryInnerStep(output); + break; + case sizeof(uint32_t): + output = TypedCopyInnermostAxisNonSolitaryInnerStep(output); + break; + case sizeof(uint64_t): + output = TypedCopyInnermostAxisNonSolitaryInnerStep(output); + break; + default: + ORT_THROW("Unexpected element size of ", element_size_); + } + } else { + output = TypedCopyInnermostAxisNonSolitaryInnerStep(output); + } + + return output; + } + + public: // splitting the function that copies the innermost dimension into 2 separate methods, // CopyInnermostAxisSolitaryInnerStep and CopyInnermostAxisNonSolitaryInnerStep, // as this is most likely being called within a loop @@ -238,33 +277,78 @@ struct SliceIterator { // up to the caller to call the correct one based on SolitaryInnerStep(). bool SolitaryInnerStep() const { return inner_step_ == 1; } - // Assumes SolitaryInnerStep() == true - T* CopyInnermostAxisSolitaryInnerStep(T* output) { - std::copy(input_, input_ + inner_extent_, output); - input_ += inner_extent_; - output += inner_extent_; - AdvanceOverInnerExtent(); - return output; - } - - // Assumes generic inner_step_ - T* CopyInnermostAxisNonSolitaryInnerStep(T* output) { + private: + template + void* TypedCopyInnermostAxisNonSolitaryInnerStep(void* output) { + // sizeof(T) == element_size_ + T* out = reinterpret_cast(output); for (size_t i = 0; i < inner_extent_; ++i) { - *output++ = *input_; + *out++ = *reinterpret_cast(input_); IncrementInnerDimension(); } - return output; + + return out; } - private: const Tensor& tensor_; - const T* input_{tensor_.template Data()}; + const bool is_string_tensor_{tensor_.IsDataTypeString()}; + // we do everything in this class using bytes to minimize binary size + const byte* input_{reinterpret_cast(tensor_.DataRaw())}; + const int64_t element_size_ = tensor_.DataType()->Size(); + gsl::span extents_; size_t inner_counter_{}, inner_extent_, inner_step_; SliceSkips skips_; std::vector indices_; // There is no index for innermost axis since it's a special case }; +// This provides easy sequential iteration over a subset of a tensor given a span of starts, extents & optionally steps +template +struct SliceIterator : public SliceIteratorBase { + SliceIterator(const Tensor& tensor, gsl::span starts, + gsl::span extents, gsl::span steps) + : SliceIteratorBase(tensor, starts, extents, steps) { + } + + // This construct takes a explicit tensor_shape which might be different from the shape defined in input tensor. + // The explicit tensor_shape usually has inner most axis flattened. For example, given shape[1,4,4,2], if last axis + // does not have padding or slice, then it will be flattened as [1,4,8] for better performance (One inner most copy instead of 4). + // Also supports arbitrary positive and negative stepping along individual axes + SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span starts, + gsl::span extents, gsl::span steps) + : SliceIteratorBase(tensor, tensor_shape, starts, extents, steps) { + } + + // postfix iterator increment + const T* operator++(int) { + const T* input = static_cast(cur_input()); + IncrementInnerDimension(); + return input; + } + + // prefix iterator increment + const T* operator++() { + IncrementInnerDimension(); + return static_cast(cur_input()); + } + + const T& operator*() const { + return *static_cast(cur_input()); + } + + // Assumes SolitaryInnerStep() == true + T* CopyInnermostAxisSolitaryInnerStep(T* output) { + void* new_output = SliceIteratorBase::CopyInnermostAxisSolitaryInnerStep(output); + return static_cast(new_output); + } + + // Assumes generic inner_step_ + T* CopyInnermostAxisNonSolitaryInnerStep(T* output) { + void* new_output = SliceIteratorBase::CopyInnermostAxisNonSolitaryInnerStep(output); + return static_cast(new_output); + } +}; + inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) { void* target = tgt->MutableDataRaw(); const void* source = src->DataRaw(); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index ffcdbb23f2..2660ef3bc9 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -85,7 +85,7 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { p_flattened_output_dims)); } else { - ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_, + ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), input_dimensions, starts, steps, output_dims, p_flattened_output_dims)); }