mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Reduce binary size of Slice implementation (#3238)
* Make the Slice implementation based on type sizes and reduce templatized code to a minimum. * Remove using 'dynamic' as a template param to Slice as well.
This commit is contained in:
parent
53b9d52fc6
commit
48e96ea65f
7 changed files with 220 additions and 289 deletions
|
|
@ -3,33 +3,19 @@
|
|||
|
||||
#include "core/providers/cpu/tensor/slice.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace onnxruntime::common;
|
||||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
DynamicSlice, \
|
||||
1, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
|
||||
Slice<data_type, true>);
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
DynamicSlice,
|
||||
1,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string);
|
||||
|
||||
} // namespace contrib_ops
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -42,19 +42,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastG
|
|||
// we cannot change the domain now as this will break backward compatibility.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, DynamicSlice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus);
|
||||
|
|
@ -126,19 +114,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
// contrib ops to main backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ImageScaler)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,
|
||||
|
|
|
|||
|
|
@ -189,19 +189,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, bool, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MLFloat16, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint8_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint16_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint32_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, uint64_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int8_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int16_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int32_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, string, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, DepthToSpace);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, 10, Split);
|
||||
|
|
@ -300,19 +288,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, bool, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, float, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, double, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint16_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint32_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint64_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int16_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, string, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
|
||||
|
|
@ -371,19 +347,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Fl
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Compress);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Concat);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gather);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint16_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int8_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int16_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, string, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze);
|
||||
|
|
@ -773,32 +737,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
bool, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
float, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
double, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
MLFloat16, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
uint8_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
uint16_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
uint32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
uint64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
int8_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
int16_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
string, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
|
||||
DepthToSpace)>,
|
||||
|
|
@ -969,32 +909,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
MatMulInteger)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
bool, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
float, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
double, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
MLFloat16, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
uint8_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
uint16_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
uint32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
uint64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
int8_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
int16_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
string, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
|
||||
NonMaxSuppression)>,
|
||||
|
|
@ -1044,32 +960,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint16_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint64_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int8_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int16_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, string,
|
||||
Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze)>,
|
||||
|
|
|
|||
|
|
@ -11,74 +11,29 @@ using namespace ::onnxruntime::common;
|
|||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
#define ADD_TYPED_SLICE_V9_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
1, 9, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
|
||||
Slice<data_type, false>);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
1, 9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
Slice1);
|
||||
|
||||
ADD_TYPED_SLICE_V9_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_V9_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_V9_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_V9_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_V9_OP(int8_t);
|
||||
ADD_TYPED_SLICE_V9_OP(int16_t);
|
||||
ADD_TYPED_SLICE_V9_OP(int32_t);
|
||||
ADD_TYPED_SLICE_V9_OP(int64_t);
|
||||
ADD_TYPED_SLICE_V9_OP(float);
|
||||
ADD_TYPED_SLICE_V9_OP(double);
|
||||
ADD_TYPED_SLICE_V9_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_V9_OP(bool);
|
||||
ADD_TYPED_SLICE_V9_OP(string);
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
10, 10,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
|
||||
#define ADD_TYPED_SLICE_V10_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
10, \
|
||||
10, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()).TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
|
||||
Slice<data_type, true>);
|
||||
|
||||
ADD_TYPED_SLICE_V10_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_V10_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_V10_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_V10_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_V10_OP(int8_t);
|
||||
ADD_TYPED_SLICE_V10_OP(int16_t);
|
||||
ADD_TYPED_SLICE_V10_OP(int32_t);
|
||||
ADD_TYPED_SLICE_V10_OP(int64_t);
|
||||
ADD_TYPED_SLICE_V10_OP(float);
|
||||
ADD_TYPED_SLICE_V10_OP(double);
|
||||
ADD_TYPED_SLICE_V10_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_V10_OP(bool);
|
||||
ADD_TYPED_SLICE_V10_OP(string);
|
||||
|
||||
#define ADD_TYPED_SLICE_V11_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
11, \
|
||||
data_type, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}), \
|
||||
Slice<data_type, true>);
|
||||
|
||||
ADD_TYPED_SLICE_V11_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_V11_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_V11_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_V11_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_V11_OP(int8_t);
|
||||
ADD_TYPED_SLICE_V11_OP(int16_t);
|
||||
ADD_TYPED_SLICE_V11_OP(int32_t);
|
||||
ADD_TYPED_SLICE_V11_OP(int64_t);
|
||||
ADD_TYPED_SLICE_V11_OP(float);
|
||||
ADD_TYPED_SLICE_V11_OP(double);
|
||||
ADD_TYPED_SLICE_V11_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_V11_OP(bool);
|
||||
ADD_TYPED_SLICE_V11_OP(string);
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Slice,
|
||||
11,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
|
||||
namespace {
|
||||
// std::clamp doesn't exist until C++17 so create a local version
|
||||
|
|
@ -315,12 +270,12 @@ void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status SliceImpl(OpKernelContext* ctx,
|
||||
const Tensor& input_tensor,
|
||||
std::vector<int64_t>& output_dims,
|
||||
std::vector<int64_t>* flattened_output_dims,
|
||||
const std::vector<int64_t>& starts,
|
||||
const std::vector<int64_t>& steps) {
|
||||
static Status SliceImpl(OpKernelContext* ctx,
|
||||
const Tensor& input_tensor,
|
||||
std::vector<int64_t>& output_dims,
|
||||
std::vector<int64_t>* flattened_output_dims,
|
||||
const std::vector<int64_t>& starts,
|
||||
const std::vector<int64_t>& steps) {
|
||||
TensorShape output_shape(output_dims);
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
|
||||
|
|
@ -328,7 +283,8 @@ Status SliceImpl(OpKernelContext* ctx,
|
|||
if (output_shape.Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
auto* output = output_tensor.template MutableData<T>();
|
||||
// use MutableDataRaw as actual data type in tensor may not match as we templatize on data size
|
||||
T* output = reinterpret_cast<T*>(output_tensor.MutableDataRaw());
|
||||
const auto* output_end = output + output_tensor.Shape().Size();
|
||||
|
||||
auto create_output = [&output, &output_end](SliceIterator<T>& input_iterator) {
|
||||
|
|
@ -363,8 +319,7 @@ Status SliceImpl(OpKernelContext* ctx,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, bool dynamic>
|
||||
Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
|
||||
Status SliceBase::Compute(OpKernelContext* ctx) const {
|
||||
const auto* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");
|
||||
const auto& input_tensor = *input_tensor_ptr;
|
||||
|
|
@ -379,7 +334,7 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t>* p_flattened_output_dims = &flattened_output_dims;
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
if (dynamic) {
|
||||
if (dynamic_) {
|
||||
std::vector<int64_t> input_starts;
|
||||
std::vector<int64_t> input_ends;
|
||||
std::vector<int64_t> input_axes;
|
||||
|
|
@ -396,6 +351,31 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
|
|||
p_flattened_output_dims));
|
||||
}
|
||||
|
||||
return SliceImpl<T>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
Status status = Status::OK();
|
||||
|
||||
if (input_tensor.IsDataTypeString()) {
|
||||
status = SliceImpl<std::string>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
} else {
|
||||
const auto element_size = input_tensor.DataType()->Size();
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(uint32_t):
|
||||
status = SliceImpl<uint32_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
status = SliceImpl<uint64_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
status = SliceImpl<uint16_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
status = SliceImpl<uint8_t>(ctx, input_tensor, output_dims, p_flattened_output_dims, starts, steps);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", input_tensor.DataType());
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ namespace onnxruntime {
|
|||
|
||||
class SliceBase {
|
||||
protected:
|
||||
SliceBase(const OpKernelInfo& info, bool dynamic = false) {
|
||||
SliceBase(const OpKernelInfo& info, bool dynamic = false)
|
||||
: dynamic_(dynamic) {
|
||||
if (!dynamic) {
|
||||
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
|
||||
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
|
||||
|
|
@ -49,13 +50,26 @@ class SliceBase {
|
|||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) const;
|
||||
|
||||
Status Compute(OpKernelContext* context) const;
|
||||
|
||||
protected:
|
||||
const std::vector<int64_t>& StartsAttribute() const { return attr_starts_; }
|
||||
const std::vector<int64_t>& EndsAttribute() const { return attr_ends_; }
|
||||
const std::vector<int64_t>& AxesAttribute() const { return attr_axes_; }
|
||||
|
||||
private:
|
||||
bool dynamic_;
|
||||
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
|
||||
};
|
||||
|
||||
template <typename T, bool dynamic>
|
||||
struct Slice final : public OpKernel, public SliceBase {
|
||||
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
}; // namespace onnxruntime
|
||||
struct Slice1 final : public OpKernel, public SliceBase {
|
||||
Slice1(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, false) {}
|
||||
Status Compute(OpKernelContext* context) const override { return SliceBase::Compute(context); }
|
||||
};
|
||||
|
||||
struct Slice10 final : public OpKernel, public SliceBase {
|
||||
Slice10(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, true) {}
|
||||
Status Compute(OpKernelContext* context) const override { return SliceBase::Compute(context); }
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -157,10 +157,14 @@ struct SliceSkips : std::vector<int64_t> {
|
|||
};
|
||||
|
||||
// This provides easy sequential iteration over a subset of a tensor given a span of starts, extents & optionally steps
|
||||
template <typename T>
|
||||
struct SliceIterator {
|
||||
SliceIterator(const Tensor& tensor, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
// The base class is type agnostic to minimize binary size. The derived class provides any type specific logic.
|
||||
struct SliceIteratorBase {
|
||||
private:
|
||||
enum class byte : unsigned char {};
|
||||
|
||||
protected:
|
||||
SliceIteratorBase(const Tensor& tensor, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
: tensor_(tensor), extents_(extents), skips_(tensor_.Shape(), extents, steps), indices_(extents.size(), 0) {
|
||||
auto& dims = tensor_.Shape().GetDims();
|
||||
Init(dims, starts, steps);
|
||||
|
|
@ -170,16 +174,15 @@ struct SliceIterator {
|
|||
// The explicit tensor_shape usually has inner most axis flattened. For example, given shape[1,4,4,2], if last axis
|
||||
// does not have padding or slice, then it will be flattened as [1,4,8] for better performance (One inner most copy instead of 4).
|
||||
// Also supports arbitrary positive and negative stepping along individual axes
|
||||
SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
SliceIteratorBase(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
: tensor_(tensor), extents_(extents), skips_(tensor_shape, extents, steps), indices_(extents.size(), 0) {
|
||||
const auto& dims = tensor_shape.GetDims();
|
||||
Init(dims, starts, steps);
|
||||
}
|
||||
|
||||
// Initialize initial skip and inner_extent.
|
||||
void Init(const std::vector<int64_t>& dims, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> steps) {
|
||||
void Init(const std::vector<int64_t>& dims, gsl::span<const int64_t> starts, gsl::span<const int64_t> steps) {
|
||||
ORT_ENFORCE(dims.size() == starts.size() &&
|
||||
dims.size() == extents_.size() &&
|
||||
dims.size() >= steps.size());
|
||||
|
|
@ -187,7 +190,7 @@ struct SliceIterator {
|
|||
size_t pitch = 1;
|
||||
// Initial skip, so that input_ points to the first element to copy
|
||||
for (size_t i = dims.size(); i-- > 0;) {
|
||||
input_ += pitch * starts[i];
|
||||
input_ += pitch * starts[i] * element_size_;
|
||||
pitch *= dims[i];
|
||||
}
|
||||
|
||||
|
|
@ -199,38 +202,74 @@ struct SliceIterator {
|
|||
|
||||
void AdvanceOverInnerExtent() {
|
||||
size_t axis = skips_.size() - 1;
|
||||
input_ += skips_[axis];
|
||||
input_ += skips_[axis] * element_size_;
|
||||
while (axis-- && ++indices_[axis] == extents_[axis]) {
|
||||
indices_[axis] = 0;
|
||||
input_ += skips_[axis];
|
||||
input_ += skips_[axis] * element_size_;
|
||||
}
|
||||
}
|
||||
|
||||
void IncrementInnerDimension() {
|
||||
input_ += inner_step_;
|
||||
input_ += inner_step_ * element_size_;
|
||||
if (++inner_counter_ == inner_extent_) {
|
||||
inner_counter_ = 0;
|
||||
AdvanceOverInnerExtent();
|
||||
}
|
||||
}
|
||||
|
||||
// postfix iterator increment
|
||||
const T* operator++(int) {
|
||||
const T* input = input_;
|
||||
IncrementInnerDimension();
|
||||
return input;
|
||||
}
|
||||
|
||||
// prefix iterator increment
|
||||
const T* operator++() {
|
||||
IncrementInnerDimension();
|
||||
const void* cur_input() const {
|
||||
return input_;
|
||||
}
|
||||
|
||||
const T& operator*() const {
|
||||
return *input_;
|
||||
// Assumes SolitaryInnerStep() == true
|
||||
void* CopyInnermostAxisSolitaryInnerStep(void* output) {
|
||||
byte* out_bytes = reinterpret_cast<byte*>(output);
|
||||
auto bytes_to_copy = inner_extent_ * element_size_;
|
||||
|
||||
if (!is_string_tensor_) {
|
||||
std::copy(input_, input_ + bytes_to_copy, out_bytes);
|
||||
} else {
|
||||
const std::string* input = reinterpret_cast<const std::string*>(input_);
|
||||
std::string* out = reinterpret_cast<std::string*>(output);
|
||||
std::copy(input, input + inner_extent_, out);
|
||||
}
|
||||
|
||||
input_ += bytes_to_copy;
|
||||
out_bytes += bytes_to_copy;
|
||||
AdvanceOverInnerExtent();
|
||||
|
||||
return out_bytes;
|
||||
}
|
||||
|
||||
// Assumes generic inner_step_
|
||||
void* CopyInnermostAxisNonSolitaryInnerStep(void* output) {
|
||||
// need to special case std::string so the copy works correctly
|
||||
if (!is_string_tensor_) {
|
||||
// switch on element size so copy is efficient
|
||||
switch (element_size_) {
|
||||
case sizeof(uint8_t):
|
||||
output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint8_t>(output);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint16_t>(output);
|
||||
break;
|
||||
case sizeof(uint32_t):
|
||||
output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint32_t>(output);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
output = TypedCopyInnermostAxisNonSolitaryInnerStep<uint64_t>(output);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unexpected element size of ", element_size_);
|
||||
}
|
||||
} else {
|
||||
output = TypedCopyInnermostAxisNonSolitaryInnerStep<std::string>(output);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
public:
|
||||
// splitting the function that copies the innermost dimension into 2 separate methods,
|
||||
// CopyInnermostAxisSolitaryInnerStep and CopyInnermostAxisNonSolitaryInnerStep,
|
||||
// as this is most likely being called within a loop
|
||||
|
|
@ -238,33 +277,78 @@ struct SliceIterator {
|
|||
// up to the caller to call the correct one based on SolitaryInnerStep().
|
||||
bool SolitaryInnerStep() const { return inner_step_ == 1; }
|
||||
|
||||
// Assumes SolitaryInnerStep() == true
|
||||
T* CopyInnermostAxisSolitaryInnerStep(T* output) {
|
||||
std::copy(input_, input_ + inner_extent_, output);
|
||||
input_ += inner_extent_;
|
||||
output += inner_extent_;
|
||||
AdvanceOverInnerExtent();
|
||||
return output;
|
||||
}
|
||||
|
||||
// Assumes generic inner_step_
|
||||
T* CopyInnermostAxisNonSolitaryInnerStep(T* output) {
|
||||
private:
|
||||
template <typename T>
|
||||
void* TypedCopyInnermostAxisNonSolitaryInnerStep(void* output) {
|
||||
// sizeof(T) == element_size_
|
||||
T* out = reinterpret_cast<T*>(output);
|
||||
for (size_t i = 0; i < inner_extent_; ++i) {
|
||||
*output++ = *input_;
|
||||
*out++ = *reinterpret_cast<const T*>(input_);
|
||||
IncrementInnerDimension();
|
||||
}
|
||||
return output;
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
const Tensor& tensor_;
|
||||
const T* input_{tensor_.template Data<T>()};
|
||||
const bool is_string_tensor_{tensor_.IsDataTypeString()};
|
||||
// we do everything in this class using bytes to minimize binary size
|
||||
const byte* input_{reinterpret_cast<const byte*>(tensor_.DataRaw())};
|
||||
const int64_t element_size_ = tensor_.DataType()->Size();
|
||||
|
||||
gsl::span<const int64_t> extents_;
|
||||
size_t inner_counter_{}, inner_extent_, inner_step_;
|
||||
SliceSkips skips_;
|
||||
std::vector<int64_t> indices_; // There is no index for innermost axis since it's a special case
|
||||
};
|
||||
|
||||
// This provides easy sequential iteration over a subset of a tensor given a span of starts, extents & optionally steps
|
||||
template <typename T>
|
||||
struct SliceIterator : public SliceIteratorBase {
|
||||
SliceIterator(const Tensor& tensor, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
: SliceIteratorBase(tensor, starts, extents, steps) {
|
||||
}
|
||||
|
||||
// This construct takes a explicit tensor_shape which might be different from the shape defined in input tensor.
|
||||
// The explicit tensor_shape usually has inner most axis flattened. For example, given shape[1,4,4,2], if last axis
|
||||
// does not have padding or slice, then it will be flattened as [1,4,8] for better performance (One inner most copy instead of 4).
|
||||
// Also supports arbitrary positive and negative stepping along individual axes
|
||||
SliceIterator(const Tensor& tensor, const TensorShape& tensor_shape, gsl::span<const int64_t> starts,
|
||||
gsl::span<const int64_t> extents, gsl::span<const int64_t> steps)
|
||||
: SliceIteratorBase(tensor, tensor_shape, starts, extents, steps) {
|
||||
}
|
||||
|
||||
// postfix iterator increment
|
||||
const T* operator++(int) {
|
||||
const T* input = static_cast<const T*>(cur_input());
|
||||
IncrementInnerDimension();
|
||||
return input;
|
||||
}
|
||||
|
||||
// prefix iterator increment
|
||||
const T* operator++() {
|
||||
IncrementInnerDimension();
|
||||
return static_cast<const T*>(cur_input());
|
||||
}
|
||||
|
||||
const T& operator*() const {
|
||||
return *static_cast<const T*>(cur_input());
|
||||
}
|
||||
|
||||
// Assumes SolitaryInnerStep() == true
|
||||
T* CopyInnermostAxisSolitaryInnerStep(T* output) {
|
||||
void* new_output = SliceIteratorBase::CopyInnermostAxisSolitaryInnerStep(output);
|
||||
return static_cast<T*>(new_output);
|
||||
}
|
||||
|
||||
// Assumes generic inner_step_
|
||||
T* CopyInnermostAxisNonSolitaryInnerStep(T* output) {
|
||||
void* new_output = SliceIteratorBase::CopyInnermostAxisNonSolitaryInnerStep(output);
|
||||
return static_cast<T*>(new_output);
|
||||
}
|
||||
};
|
||||
|
||||
inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) {
|
||||
void* target = tgt->MutableDataRaw();
|
||||
const void* source = src->DataRaw();
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
p_flattened_output_dims));
|
||||
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(),
|
||||
input_dimensions, starts, steps, output_dims,
|
||||
p_flattened_output_dims));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue