mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
executable size reduction: cleaned up slice op to get savings (#621)
* Initial commit * More changes * More changes * Fix build break
This commit is contained in:
parent
3f52de07c7
commit
1aa24cbbf3
4 changed files with 142 additions and 163 deletions
|
|
@ -177,33 +177,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice);
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
|
||||
|
|
@ -442,33 +428,19 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
|
||||
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice)>());
|
||||
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice)>());
|
||||
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());
|
||||
|
|
|
|||
|
|
@ -8,64 +8,51 @@ using namespace std;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
#define ADD_TYPED_SLICE_OP(data_type, indice_type) \
|
||||
#define ADD_TYPED_SLICE_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
1, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
|
||||
Slice<data_type, indice_type, false>);
|
||||
Slice<data_type, false>);
|
||||
|
||||
ADD_TYPED_SLICE_OP(uint8_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint16_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint32_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint64_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int8_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int16_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int32_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int64_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(float, int64_t);
|
||||
ADD_TYPED_SLICE_OP(double, int64_t);
|
||||
ADD_TYPED_SLICE_OP(MLFloat16,int64_t);
|
||||
ADD_TYPED_SLICE_OP(bool, int64_t);
|
||||
ADD_TYPED_SLICE_OP(string, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_OP(int8_t);
|
||||
ADD_TYPED_SLICE_OP(int16_t);
|
||||
ADD_TYPED_SLICE_OP(int32_t);
|
||||
ADD_TYPED_SLICE_OP(int64_t);
|
||||
ADD_TYPED_SLICE_OP(float);
|
||||
ADD_TYPED_SLICE_OP(double);
|
||||
ADD_TYPED_SLICE_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_OP(bool);
|
||||
ADD_TYPED_SLICE_OP(string);
|
||||
|
||||
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
DynamicSlice, \
|
||||
1, \
|
||||
data_type##_##indice_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<indice_type>()), \
|
||||
Slice<data_type, indice_type, true>);
|
||||
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
DynamicSlice, \
|
||||
1, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), \
|
||||
DataTypeImpl::GetTensorType<int64_t>()}), \
|
||||
Slice<data_type, true>);
|
||||
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t);
|
||||
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string);
|
||||
|
||||
namespace {
|
||||
// std::clamp doesn't exist until C++17 so create a local version
|
||||
|
|
@ -78,12 +65,11 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
|
|||
} // namespace
|
||||
|
||||
Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_axes,
|
||||
const size_t dimension_count,
|
||||
const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const {
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const {
|
||||
// Initialize axes to the provided axes attribute or to the default sequence
|
||||
std::vector<int64_t> axes(raw_axes);
|
||||
if (axes.size() == 0) {
|
||||
|
|
@ -93,6 +79,7 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
|||
}
|
||||
|
||||
// Iterate through the provided axes and override the start/end ranges
|
||||
const auto& dimension_count = input_dimensions.size();
|
||||
for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) {
|
||||
auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast<int64_t>(dimension_count) : axes[axesIndex];
|
||||
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
|
||||
|
|
@ -113,63 +100,84 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Tind>
|
||||
void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes) const {
|
||||
auto stat_tensor = context->Input<Tensor>(1);
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes) const {
|
||||
auto start_tensor = context->Input<Tensor>(1);
|
||||
auto ends_tensor = context->Input<Tensor>(2);
|
||||
auto axes_tensor = context->Input<Tensor>(3);
|
||||
|
||||
ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" );
|
||||
ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" );
|
||||
ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array");
|
||||
ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array");
|
||||
ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
|
||||
auto size = stat_tensor->Shape().Size();
|
||||
const auto& dtype = start_tensor->DataType();
|
||||
const auto& size = start_tensor->Shape().Size();
|
||||
input_starts.resize(size);
|
||||
std::copy(stat_tensor->Data<Tind>(), stat_tensor->Data<Tind>() + size, input_starts.begin());
|
||||
input_ends.resize(size);
|
||||
std::copy(ends_tensor->Data<Tind>(), ends_tensor->Data<Tind>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor) {
|
||||
if (nullptr != axes_tensor)
|
||||
input_axes.resize(size);
|
||||
std::copy(axes_tensor->Data<Tind>(), axes_tensor->Data<Tind>() + size, input_axes.begin());
|
||||
|
||||
if (dtype == DataTypeImpl::GetType<int32_t>()) {
|
||||
std::copy(start_tensor->Data<int32_t>(), start_tensor->Data<int32_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor->Data<int32_t>(), ends_tensor->Data<int32_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
std::copy(axes_tensor->Data<int32_t>(), axes_tensor->Data<int32_t>() + size, input_axes.begin());
|
||||
}
|
||||
|
||||
else if (dtype == DataTypeImpl::GetType<int64_t>()) {
|
||||
std::copy(start_tensor->Data<int64_t>(), start_tensor->Data<int64_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor->Data<int64_t>(), ends_tensor->Data<int64_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
std::copy(axes_tensor->Data<int64_t>(), axes_tensor->Data<int64_t>() + size, input_axes.begin());
|
||||
}
|
||||
|
||||
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
|
||||
else {
|
||||
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", dtype);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr);
|
||||
auto& input_tensor = *input_tensor_ptr;
|
||||
auto& input_dimensions = input_tensor.Shape().GetDims();
|
||||
|
||||
// Initialize the starts & ends to the actual tensor shape
|
||||
const size_t dimension_count = input_dimensions.size();
|
||||
std::vector<int64_t> starts(dimension_count, 0);
|
||||
std::vector<int64_t> output_dims(input_dimensions);
|
||||
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SliceImpl(OpKernelContext* ctx,
|
||||
const Tensor& input_tensor,
|
||||
std::vector<int64_t>& output_dims,
|
||||
const std::vector<int64_t>& starts) {
|
||||
TensorShape output_shape(output_dims);
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
auto* output = output_tensor.template MutableData<T>();
|
||||
const auto* output_end = output + output_shape.Size();
|
||||
const auto* output_end = output + output_tensor.Shape().Size();
|
||||
|
||||
SliceIterator<T> input_iterator(input_tensor, starts, output_dims);
|
||||
SliceIterator<T> input_iterator(input_tensor, starts, output_tensor.Shape().GetDims());
|
||||
while (output != output_end)
|
||||
*output++ = *input_iterator++;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, bool dynamic>
|
||||
Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");
|
||||
const auto& input_tensor = *input_tensor_ptr;
|
||||
const auto& input_dimensions = input_tensor.Shape().GetDims();
|
||||
|
||||
// Initialize the starts & ends to the actual tensor shape
|
||||
std::vector<int64_t> starts(input_dimensions.size(), 0);
|
||||
std::vector<int64_t> output_dims(input_dimensions);
|
||||
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectorsFromInput(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
input_dimensions, starts, output_dims));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
input_dimensions, starts, output_dims));
|
||||
}
|
||||
|
||||
return SliceImpl<T>(ctx, input_tensor, output_dims, starts);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,35 +9,34 @@ namespace onnxruntime {
|
|||
|
||||
class SliceBase {
|
||||
protected:
|
||||
SliceBase (const OpKernelInfo& info, bool dynamic = false) {
|
||||
SliceBase(const OpKernelInfo& info, bool dynamic = false) {
|
||||
if (!dynamic) {
|
||||
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
|
||||
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
|
||||
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
|
||||
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
|
||||
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
|
||||
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
|
||||
"Missing or invalid starts and ends attribute");
|
||||
"Missing or invalid starts and ends attribute");
|
||||
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
|
||||
"Invalid axes attribute");
|
||||
"Invalid axes attribute, axes attribute (if present) should have the same size as starts/ends attributes");
|
||||
}
|
||||
}
|
||||
|
||||
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_axes,
|
||||
const size_t dimension_count,
|
||||
const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const;
|
||||
template<typename Tind>
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const;
|
||||
|
||||
void FillVectorsFromInput(const OpKernelContext* context,
|
||||
std::vector<int64_t>& raw_starts,
|
||||
std::vector<int64_t>& raw_ends,
|
||||
std::vector<int64_t>& raw_axes) const;
|
||||
std::vector<int64_t>& raw_starts,
|
||||
std::vector<int64_t>& raw_ends,
|
||||
std::vector<int64_t>& raw_axes) const;
|
||||
|
||||
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
|
||||
};
|
||||
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
template <typename T, bool dynamic>
|
||||
struct Slice final : public OpKernel, public SliceBase {
|
||||
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
|
|||
|
|
@ -40,13 +40,13 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
|
||||
FillVectorsFromInput(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
input_dimensions, starts, output_dims));
|
||||
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
input_dimensions, starts, output_dims));
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
|
|
|
|||
Loading…
Reference in a new issue