Reduce binary size of Slice implementation (#3238)

* Make the Slice implementation based on type sizes and reduce templatized code to a minimum.

* Also stop using 'dynamic' as a template parameter of Slice.
This commit is contained in:
Scott McKay 2020-04-08 07:19:29 +10:00 committed by GitHub
parent 53b9d52fc6
commit 48e96ea65f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 220 additions and 289 deletions

View file

@ -3,33 +3,19 @@
#include "core/providers/cpu/tensor/slice.h"
using namespace ::onnxruntime::common;
using namespace onnxruntime::common;
using namespace std;
namespace onnxruntime {
namespace contrib {
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
DynamicSlice, \
1, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
Slice<data_type, true>);
// Registers a single CPU kernel for the DynamicSlice contrib op, version 1.
// (ONNX_CPU_OPERATOR_KERNEL is defined outside this file; the notes below are read off the call site.)
//  - "T"    : constrained to DataTypeImpl::AllTensorTypes(), so one registration covers every data type.
//  - "Tind" : the index inputs may be int32 or int64 tensors.
//  - Slice10 is the kernel class that implements the op.
ONNX_CPU_OPERATOR_KERNEL(
DynamicSlice,
1,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float);
ADD_TYPED_DYNAMIC_SLICE_OP(double);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16);
ADD_TYPED_DYNAMIC_SLICE_OP(bool);
ADD_TYPED_DYNAMIC_SLICE_OP(string);
} // namespace contrib_ops
} // namespace contrib
} // namespace onnxruntime

View file

@ -42,19 +42,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastG
// we cannot change the domain now as this will break backward compatibility.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, DynamicSlice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus);
@ -126,19 +114,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
// contrib ops to maintain backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,

View file

@ -189,19 +189,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, bool, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MLFloat16, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint8_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint16_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint64_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int8_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int16_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, string, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, DepthToSpace);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, 10, Split);
@ -300,19 +288,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, bool, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, float, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, double, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint16_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint64_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int16_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, string, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
@ -371,19 +347,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Fl
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Compress);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Concat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gather);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint16_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int8_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int16_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, string, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze);
@ -773,32 +737,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
bool, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
float, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
double, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
MLFloat16, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
uint8_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
uint16_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
uint32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
uint64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
int8_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
int16_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
string, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
DepthToSpace)>,
@ -969,32 +909,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
bool, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
float, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
double, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
MLFloat16, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
uint8_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
uint16_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
uint32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
uint64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
int8_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
int16_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
string, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
NonMaxSuppression)>,
@ -1044,32 +960,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Compress)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint16_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int8_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int16_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, string,
Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze)>,

View file

@ -11,74 +11,29 @@ using namespace ::onnxruntime::common;
using namespace std;
namespace onnxruntime {
#define ADD_TYPED_SLICE_V9_OP(data_type) \
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
Slice, \
1, 9, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Slice<data_type, false>);
// Registers a single CPU kernel for Slice, versions 1 through 9.
// "T" is constrained to DataTypeImpl::AllTensorTypes(), so one registration covers every data type.
// There is no "Tind" constraint for this version range. Slice1 is the implementing kernel class.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
1, 9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Slice1);
ADD_TYPED_SLICE_V9_OP(uint8_t);
ADD_TYPED_SLICE_V9_OP(uint16_t);
ADD_TYPED_SLICE_V9_OP(uint32_t);
ADD_TYPED_SLICE_V9_OP(uint64_t);
ADD_TYPED_SLICE_V9_OP(int8_t);
ADD_TYPED_SLICE_V9_OP(int16_t);
ADD_TYPED_SLICE_V9_OP(int32_t);
ADD_TYPED_SLICE_V9_OP(int64_t);
ADD_TYPED_SLICE_V9_OP(float);
ADD_TYPED_SLICE_V9_OP(double);
ADD_TYPED_SLICE_V9_OP(MLFloat16);
ADD_TYPED_SLICE_V9_OP(bool);
ADD_TYPED_SLICE_V9_OP(string);
// Registers a single CPU kernel for Slice, version 10 only (start and end version are both 10).
//  - "T"    : constrained to DataTypeImpl::AllTensorTypes(), so one registration covers every data type.
//  - "Tind" : the index inputs may be int32 or int64 tensors.
//  - Slice10 is the implementing kernel class.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
10, 10,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
#define ADD_TYPED_SLICE_V10_OP(data_type) \
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
Slice, \
10, \
10, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
Slice<data_type, true>);
ADD_TYPED_SLICE_V10_OP(uint8_t);
ADD_TYPED_SLICE_V10_OP(uint16_t);
ADD_TYPED_SLICE_V10_OP(uint32_t);
ADD_TYPED_SLICE_V10_OP(uint64_t);
ADD_TYPED_SLICE_V10_OP(int8_t);
ADD_TYPED_SLICE_V10_OP(int16_t);
ADD_TYPED_SLICE_V10_OP(int32_t);
ADD_TYPED_SLICE_V10_OP(int64_t);
ADD_TYPED_SLICE_V10_OP(float);
ADD_TYPED_SLICE_V10_OP(double);
ADD_TYPED_SLICE_V10_OP(MLFloat16);
ADD_TYPED_SLICE_V10_OP(bool);
ADD_TYPED_SLICE_V10_OP(string);
#define ADD_TYPED_SLICE_V11_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Slice, \
11, \
data_type, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
Slice<data_type, true>);
ADD_TYPED_SLICE_V11_OP(uint8_t);
ADD_TYPED_SLICE_V11_OP(uint16_t);
ADD_TYPED_SLICE_V11_OP(uint32_t);
ADD_TYPED_SLICE_V11_OP(uint64_t);
ADD_TYPED_SLICE_V11_OP(int8_t);
ADD_TYPED_SLICE_V11_OP(int16_t);
ADD_TYPED_SLICE_V11_OP(int32_t);
ADD_TYPED_SLICE_V11_OP(int64_t);
ADD_TYPED_SLICE_V11_OP(float);
ADD_TYPED_SLICE_V11_OP(double);
ADD_TYPED_SLICE_V11_OP(MLFloat16);
ADD_TYPED_SLICE_V11_OP(bool);
ADD_TYPED_SLICE_V11_OP(string);
// Registers a single CPU kernel for Slice, version 11.
//  - "T"    : constrained to DataTypeImpl::AllTensorTypes(), so one registration covers every data type.
//  - "Tind" : the index inputs may be int32 or int64 tensors.
//  - The implementing kernel class is Slice10; no separate class is used for version 11.
ONNX_CPU_OPERATOR_KERNEL(
Slice,
11,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
namespace {
// std::clamp doesn't exist until C++17 so create a local version
@ -315,12 +270,12 @@ void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
}
template <typename T>
Status SliceImpl(OpKernelContext* ctx,
const Tensor& input_tensor,
std::vector<int64_t>& output_dims,
std::vector<int64_t>* flattened_output_dims,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& steps) {
static Status SliceImpl(OpKernelContext* ctx,
const Tensor& input_tensor,
std::vector<int64_t>& output_dims,
std::vector<int64_t>* flattened_output_dims,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& steps) {
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);
@ -328,7 +283,8 @@ Status SliceImpl(OpKernelContext* ctx,
if (output_shape.Size() == 0)
return Status::OK();
auto* output = output_tensor.template MutableData<T>();
// use MutableDataRaw as actual data type in tensor may not match as we templatize on data size
T* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
const auto* output_end = output + output_tensor.Shape().Size();
auto create_output = [&output, &output_end](SliceIterator<T>& input_iterator) {
@ -363,8 +319,7 @@ Status SliceImpl(OpKernelContext* ctx,
return Status::OK();
}
template <typename T, bool dynamic>
Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
Status SliceBase::Compute(OpKernelContext* ctx) const {
const auto* input_tensor_ptr = ctx->Input<Tensor>(0);
ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");
const auto& input_tensor = *input_tensor_ptr;
@ -379,7 +334,7 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
std::vector<int64_t>* p_flattened_output_dims = &flattened_output_dims;
// Slice V10 & DynamicSlice
if (dynamic) {
if (dynamic_) {
std::vector<int64_t> input_starts;
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
@ -396,6 +351,31 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
p_flattened_output_dims));
}
return SliceImpl<T>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
Status status = Status::OK();
if (input_tensor.IsDataTypeString()) {
status = SliceImpl<std::string>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
} else {
const auto element_size = input_tensor.DataType()->Size();
switch (element_size) {
case sizeof(uint32_t):
status = SliceImpl<uint32_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
break;
case sizeof(uint64_t):
status = SliceImpl<uint64_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
break;
case sizeof(uint16_t):
status = SliceImpl<uint16_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
break;
case sizeof(uint8_t):
status = SliceImpl<uint8_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
break;
default:
ORT_THROW("Unsupported input data type of ", input_tensor.DataType());
}
}
return status;
}
} // namespace onnxruntime

View file

@ -9,7 +9,8 @@ namespace onnxruntime {
class SliceBase {
protected:
SliceBase(const OpKernelInfo& info, bool dynamic = false) {
SliceBase(const OpKernelInfo& info, bool dynamic = false)
: dynamic_(dynamic) {
if (!dynamic) {
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
@ -49,13 +50,26 @@ class SliceBase {
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const;
Status Compute(OpKernelContext* context) const;
protected:
const std::vector<int64_t>& StartsAttribute() const { return attr_starts_; }
const std::vector<int64_t>& EndsAttribute() const { return attr_ends_; }
const std::vector<int64_t>& AxesAttribute() const { return attr_axes_; }
private:
bool dynamic_;
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
};
template <typename T, bool dynamic>
struct Slice final : public OpKernel, public SliceBase {
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
Status Compute(OpKernelContext* context) const override;
}; // namespace onnxruntime
// Slice kernel that constructs SliceBase with dynamic = false.
// It carries no logic of its own: the OpKernel::Compute override simply forwards to the
// non-templated SliceBase::Compute.
struct Slice1 final : public OpKernel, public SliceBase {
  Slice1(const OpKernelInfo& info)
      : OpKernel(info),
        SliceBase(info, false) {
  }

  Status Compute(OpKernelContext* context) const override {
    return SliceBase::Compute(context);
  }
};
// Slice kernel that constructs SliceBase with dynamic = true.
// It carries no logic of its own: the OpKernel::Compute override simply forwards to the
// non-templated SliceBase::Compute.
struct Slice10 final : public OpKernel, public SliceBase {
  Slice10(const OpKernelInfo& info)
      : OpKernel(info),
        SliceBase(info, true) {
  }

  Status Compute(OpKernelContext* context) const override {
    return SliceBase::Compute(context);
  }
};
} // namespace onnxruntime

View file

@ -157,10 +157,14 @@ struct SliceSkips : std::vector<int64_t> {
};
// This provides easy sequential iteration over a subset of a tensor given a span of starts, extents & optionally steps
template <typename T>
struct SliceIterator {
SliceIterator(const Tensor& tensor, gsl::span<const int64_t> starts,
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
// The base class is type agnostic to minimize binary size. The derived class provides any type specific logic.
struct SliceIteratorBase {
private:
enum class byte : unsigned char {};
protected:
SliceIteratorBase(const Tensor& tensor, gsl::span<const int64_t> starts,
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
: tensor_(tensor), extents_(extents), skips_(tensor_.Shape(), extents, steps), indices_(extents.size(), 0) {
auto& dims = tensor_.Shape().GetDims();
Init(dims, starts, steps);
@ -170,16 +174,15 @@ struct SliceIterator {
// The explicit tensor_shape usually has inner most axis flattened. For example, given shape[1,4,4,2], if last axis
// does not have padding or slice, then it will be flattened as [1,4,8] for better performance (One inner most copy instead of 4).
// Also supports arbitrary positive and negative stepping along individual axes
SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
SliceIteratorBase(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
: tensor_(tensor), extents_(extents), skips_(tensor_shape, extents, steps), indices_(extents.size(), 0) {
const auto& dims = tensor_shape.GetDims();
Init(dims, starts, steps);
}
// Initialize initial skip and inner_extent.
void Init(const std::vector<int64_t>& dims, gsl::span<const int64_t> starts,
gsl::span<const int64_t> steps) {
void Init(const std::vector<int64_t>& dims, gsl::span<const int64_t> starts, gsl::span<const int64_t> steps) {
ORT_ENFORCE(dims.size() == starts.size() &&
dims.size() == extents_.size() &&
dims.size() >= steps.size());
@ -187,7 +190,7 @@ struct SliceIterator {
size_t pitch = 1;
// Initial skip, so that input_ points to the first element to copy
for (size_t i = dims.size(); i-- > 0;) {
input_ += pitch * starts[i];
input_ += pitch * starts[i] * element_size_;
pitch *= dims[i];
}
@ -199,38 +202,74 @@ struct SliceIterator {
void AdvanceOverInnerExtent() {
size_t axis = skips_.size() - 1;
input_ += skips_[axis];
input_ += skips_[axis] * element_size_;
while (axis-- && ++indices_[axis] == extents_[axis]) {
indices_[axis] = 0;
input_ += skips_[axis];
input_ += skips_[axis] * element_size_;
}
}
void IncrementInnerDimension() {
input_ += inner_step_;
input_ += inner_step_ * element_size_;
if (++inner_counter_ == inner_extent_) {
inner_counter_ = 0;
AdvanceOverInnerExtent();
}
}
// postfix iterator increment
const T* operator++(int) {
const T* input = input_;
IncrementInnerDimension();
return input;
}
// prefix iterator increment
const T* operator++() {
IncrementInnerDimension();
const void* cur_input() const {
return input_;
}
const T& operator*() const {
return *input_;
// Assumes SolitaryInnerStep() == true
void* CopyInnermostAxisSolitaryInnerStep(void* output) {
byte* out_bytes = reinterpret_cast<byte*>(output);
auto bytes_to_copy = inner_extent_ * element_size_;
if (!is_string_tensor_) {
std::copy(input_, input_ + bytes_to_copy, out_bytes);
} else {
const std::string* input = reinterpret_cast<const std::string*>(input_);
std::string* out = reinterpret_cast<std::string*>(output);
std::copy(input, input + inner_extent_, out);
}
input_ += bytes_to_copy;
out_bytes += bytes_to_copy;
AdvanceOverInnerExtent();
return out_bytes;
}
// Copies one run of the innermost axis into `output` for an arbitrary inner_step_ (assumes generic inner_step_),
// delegating the element-by-element work to TypedCopyInnermostAxisNonSolitaryInnerStep<T>.
// Returns whatever the typed helper returns as the next output position.
// Throws via ORT_THROW if a non-string tensor has an element size other than 1, 2, 4 or 8 bytes.
void* CopyInnermostAxisNonSolitaryInnerStep(void* output) {
  void* next_output = output;

  if (is_string_tensor_) {
    // std::string is special cased so that each element is copied as a string object
    next_output = TypedCopyInnermostAxisNonSolitaryInnerStep<std::string>(output);
  } else {
    // pick an unsigned integer type of the same width as the element so the copy is efficient
    switch (element_size_) {
      case sizeof(uint8_t):
        next_output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint8_t>(output);
        break;
      case sizeof(uint16_t):
        next_output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint16_t>(output);
        break;
      case sizeof(uint32_t):
        next_output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint32_t>(output);
        break;
      case sizeof(uint64_t):
        next_output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint64_t>(output);
        break;
      default:
        ORT_THROW("Unexpected element size of ", element_size_);
    }
  }

  return next_output;
}
public:
// splitting the function that copies the innermost dimension into 2 separate methods,
// CopyInnermostAxisSolitaryInnerStep and CopyInnermostAxisNonSolitaryInnerStep,
// as this is most likely being called within a loop
@ -238,33 +277,78 @@ struct SliceIterator {
// up to the caller to call the correct one based on SolitaryInnerStep().
// True when inner_step_ is exactly 1; callers use this to choose between the two CopyInnermostAxis* methods.
bool SolitaryInnerStep() const { return inner_step_ == 1; }
// Copies inner_extent_ consecutive elements into `output` and returns the position just past them.
// Precondition: SolitaryInnerStep() == true.
T* CopyInnermostAxisSolitaryInnerStep(T* output) {
  // std::copy returns the position one past the last element written
  T* const output_end = std::copy(input_, input_ + inner_extent_, output);
  input_ += inner_extent_;
  AdvanceOverInnerExtent();
  return output_end;
}
// Assumes generic inner_step_
T* CopyInnermostAxisNonSolitaryInnerStep(T* output) {
private:
// Element-by-element copy of one innermost run, treating each element as a T.
// Loops inner_extent_ times and calls IncrementInnerDimension() after every element.
// NOTE(review): this span carries both an untyped (`output`) and a typed (`out`) form of the
// store and of the return statement - confirm which pair is current.
template <typename T>
void* TypedCopyInnermostAxisNonSolitaryInnerStep(void* output) {
// sizeof(T) == element_size_
T* out = reinterpret_cast<T*>(output);
for (size_t i = 0; i < inner_extent_; ++i) {
*output++ = *input_;
*out++ = *reinterpret_cast<const T*>(input_);
IncrementInnerDimension();
}
return output;
return out;
}
private:
const Tensor& tensor_;
const T* input_{tensor_.template Data<T>()};
const bool is_string_tensor_{tensor_.IsDataTypeString()};
// we do everything in this class using bytes to minimize binary size
const byte* input_{reinterpret_cast<const byte*>(tensor_.DataRaw())};
const int64_t element_size_ = tensor_.DataType()->Size();
gsl::span<const int64_t> extents_;
size_t inner_counter_{}, inner_extent_, inner_step_;
SliceSkips skips_;
std::vector<int64_t> indices_; // There is no index for innermost axis since it's a special case
};
// Typed front end over SliceIteratorBase: sequential iteration over a subset of a tensor
// described by spans of starts, extents and (optionally) steps. The iteration itself is done
// by the base class; this wrapper only converts between the untyped and T-typed pointers.
template <typename T>
struct SliceIterator : public SliceIteratorBase {
  SliceIterator(const Tensor& tensor, gsl::span<const int64_t> starts,
                gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
      : SliceIteratorBase(tensor, starts, extents, steps) {}

  // This constructor takes an explicit tensor_shape, which may differ from the shape of the input tensor.
  // The explicit tensor_shape usually has the innermost axis flattened. For example, given shape [1,4,4,2],
  // if the last axis has no padding or slice it is flattened to [1,4,8] for better performance
  // (one innermost copy instead of 4).
  // Arbitrary positive and negative stepping along individual axes is supported as well.
  SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
                gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
      : SliceIteratorBase(tensor, tensor_shape, starts, extents, steps) {}

  // Postfix increment: advance, returning the element position held before the advance.
  const T* operator++(int) {
    const void* const before = cur_input();
    IncrementInnerDimension();
    return static_cast<const T*>(before);
  }

  // Prefix increment: advance, returning the new element position.
  const T* operator++() {
    IncrementInnerDimension();
    const void* const after = cur_input();
    return static_cast<const T*>(after);
  }

  // Dereference: the element at the current position.
  const T& operator*() const {
    const T* const current = static_cast<const T*>(cur_input());
    return *current;
  }

  // Assumes SolitaryInnerStep() == true
  T* CopyInnermostAxisSolitaryInnerStep(T* output) {
    return static_cast<T*>(SliceIteratorBase::CopyInnermostAxisSolitaryInnerStep(output));
  }

  // Assumes generic inner_step_
  T* CopyInnermostAxisNonSolitaryInnerStep(T* output) {
    return static_cast<T*>(SliceIteratorBase::CopyInnermostAxisNonSolitaryInnerStep(output));
  }
};
inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) {
void* target = tgt->MutableDataRaw();
const void* source = src->DataRaw();

View file

@ -85,7 +85,7 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
p_flattened_output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(),
input_dimensions, starts, steps, output_dims,
p_flattened_output_dims));
}