executable size reduction: cleaned up slice op to get savings (#621)

* Initial commit

* More changes

* More changes

* Fix build break
This commit is contained in:
Hariharan Seshadri 2019-03-19 15:00:16 -07:00 committed by GitHub
parent 3f52de07c7
commit 1aa24cbbf3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 142 additions and 163 deletions

View file

@ -177,33 +177,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
// Forward declarations of the CPU DynamicSlice kernel classes (kOnnxDomain, version argument 1).
// NOTE(review): this diff view lists two naming schemes side by side:
//   - names keyed on <data type>_<index type> (e.g. float_int32_t, float_int64_t), and
//   - names keyed on the data type alone (e.g. float).
// Presumably the first 26 declarations are the removed lines and the last 13 the added ones,
// which would match the commit summary (fewer kernel instantiations) — confirm against the raw diff.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
@ -442,33 +428,19 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, DynamicSlice)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());

View file

@ -8,64 +8,51 @@ using namespace std;
namespace onnxruntime {
// ADD_TYPED_SLICE_OP: defines one CPU "Slice" kernel (version argument 1) per data type,
// constraining "T" to tensors of that data type and binding the kernel to a Slice<...>
// instantiation whose trailing `false` template argument selects the non-dynamic
// (attribute-driven) variant.
// NOTE(review): two revisions of the macro are interleaved below by the diff view:
//   - (data_type, indice_type) -> Slice<data_type, indice_type, false>
//   - (data_type)              -> Slice<data_type, false>
// The invocation lists that follow likewise appear once in the two-argument form and once
// in the one-argument form.
#define ADD_TYPED_SLICE_OP(data_type, indice_type) \
#define ADD_TYPED_SLICE_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Slice, \
1, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Slice<data_type, indice_type, false>);
Slice<data_type, false>);
// Two-argument invocations (index type always int64_t).
ADD_TYPED_SLICE_OP(uint8_t, int64_t);
ADD_TYPED_SLICE_OP(uint16_t, int64_t);
ADD_TYPED_SLICE_OP(uint32_t, int64_t);
ADD_TYPED_SLICE_OP(uint64_t, int64_t);
ADD_TYPED_SLICE_OP(int8_t, int64_t);
ADD_TYPED_SLICE_OP(int16_t, int64_t);
ADD_TYPED_SLICE_OP(int32_t, int64_t);
ADD_TYPED_SLICE_OP(int64_t, int64_t);
ADD_TYPED_SLICE_OP(float, int64_t);
ADD_TYPED_SLICE_OP(double, int64_t);
ADD_TYPED_SLICE_OP(MLFloat16,int64_t);
ADD_TYPED_SLICE_OP(bool, int64_t);
ADD_TYPED_SLICE_OP(string, int64_t);
// One-argument invocations (data type only).
ADD_TYPED_SLICE_OP(uint8_t);
ADD_TYPED_SLICE_OP(uint16_t);
ADD_TYPED_SLICE_OP(uint32_t);
ADD_TYPED_SLICE_OP(uint64_t);
ADD_TYPED_SLICE_OP(int8_t);
ADD_TYPED_SLICE_OP(int16_t);
ADD_TYPED_SLICE_OP(int32_t);
ADD_TYPED_SLICE_OP(int64_t);
ADD_TYPED_SLICE_OP(float);
ADD_TYPED_SLICE_OP(double);
ADD_TYPED_SLICE_OP(MLFloat16);
ADD_TYPED_SLICE_OP(bool);
ADD_TYPED_SLICE_OP(string);
// ADD_TYPED_DYNAMIC_SLICE_OP: defines CPU "DynamicSlice" kernels (version argument 1) bound to
// Slice<...> instantiations whose trailing `true` template argument selects the dynamic variant.
// NOTE(review): two revisions of the macro are interleaved below by the diff view:
//   - (data_type, indice_type): the kernel is keyed on data_type##_##indice_type and "Tind" is
//     constrained to tensors of the single indice_type, so one kernel exists per
//     (data type, index type) pair.
//   - (data_type): the kernel is keyed on data_type alone and "Tind" accepts both int32_t and
//     int64_t tensors, so one kernel exists per data type.
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
DynamicSlice, \
1, \
data_type##_##indice_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<indice_type>()), \
Slice<data_type, indice_type, true>);
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
DynamicSlice, \
1, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(), \
DataTypeImpl::GetTensorType<int64_t>()}), \
Slice<data_type, true>);
// Two-argument invocations: every data type paired with int32_t indices, then with int64_t indices.
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t);
// One-argument invocations (data type only).
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float);
ADD_TYPED_DYNAMIC_SLICE_OP(double);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16);
ADD_TYPED_DYNAMIC_SLICE_OP(bool);
ADD_TYPED_DYNAMIC_SLICE_OP(string);
namespace {
// std::clamp doesn't exist until C++17 so create a local version
@ -78,12 +65,11 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
} // namespace
Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const size_t dimension_count,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const {
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const {
// Initialize axes to the provided axes attribute or to the default sequence
std::vector<int64_t> axes(raw_axes);
if (axes.size() == 0) {
@ -93,6 +79,7 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
}
// Iterate through the provided axes and override the start/end ranges
const auto& dimension_count = input_dimensions.size();
for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) {
auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast<int64_t>(dimension_count) : axes[axesIndex];
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
@ -113,63 +100,84 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
return Status::OK();
}
// Reads the DynamicSlice index inputs from the kernel context and widens them to int64_t:
//   input 1 -> input_starts, input 2 -> input_ends, input 3 (may be absent) -> input_axes.
// Starts and ends must be non-null 1-D tensors of identical shape; axes, when present, must
// have the same shape as starts. Violations are reported through ORT_ENFORCE.
//
// NOTE(review): this span interleaves two revisions of the function as rendered by the diff view:
//   - a version templated on Tind that reads every tensor through Data<Tind>(), and
//   - a non-templated version that inspects the starts tensor's DataType() at run time and
//     copies through Data<int32_t>() or Data<int64_t>(), throwing for any other type.
// In the run-time version only the starts tensor's element type is inspected; ends and axes are
// read with that same type. Presumably Tensor::Data<T>() rejects a mismatching tensor — confirm.
template <typename Tind>
void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes) const {
auto stat_tensor = context->Input<Tensor>(1);
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes) const {
auto start_tensor = context->Input<Tensor>(1);
auto ends_tensor = context->Input<Tensor>(2);
auto axes_tensor = context->Input<Tensor>(3);
ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" );
ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" );
ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
auto size = stat_tensor->Shape().Size();
const auto& dtype = start_tensor->DataType();
const auto& size = start_tensor->Shape().Size();
input_starts.resize(size);
std::copy(stat_tensor->Data<Tind>(), stat_tensor->Data<Tind>() + size, input_starts.begin());
input_ends.resize(size);
std::copy(ends_tensor->Data<Tind>(), ends_tensor->Data<Tind>() + size, input_ends.begin());
if (nullptr != axes_tensor) {
if (nullptr != axes_tensor)
input_axes.resize(size);
std::copy(axes_tensor->Data<Tind>(), axes_tensor->Data<Tind>() + size, input_axes.begin());
// Run-time dispatch on the element type of the starts tensor; std::copy widens each element to int64_t.
if (dtype == DataTypeImpl::GetType<int32_t>()) {
std::copy(start_tensor->Data<int32_t>(), start_tensor->Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int32_t>(), ends_tensor->Data<int32_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int32_t>(), axes_tensor->Data<int32_t>() + size, input_axes.begin());
}
else if (dtype == DataTypeImpl::GetType<int64_t>()) {
std::copy(start_tensor->Data<int64_t>(), start_tensor->Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int64_t>(), ends_tensor->Data<int64_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int64_t>(), axes_tensor->Data<int64_t>() + size, input_axes.begin());
}
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
else {
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", dtype);
}
}
template <typename T, typename Tind, bool dynamic>
Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
ORT_ENFORCE(input_tensor_ptr != nullptr);
auto& input_tensor = *input_tensor_ptr;
auto& input_dimensions = input_tensor.Shape().GetDims();
// Initialize the starts & ends to the actual tensor shape
const size_t dimension_count = input_dimensions.size();
std::vector<int64_t> starts(dimension_count, 0);
std::vector<int64_t> output_dims(input_dimensions);
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes;
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
dimension_count, input_dimensions, starts, output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions, starts, output_dims));
}
// Copies the selected slice of input_tensor into output 0 of the kernel context.
//   ctx          - kernel context; output 0 is requested with the shape built from output_dims.
//   input_tensor - tensor to read from.
//   output_dims  - dimensions of the slice to produce.
//   starts       - start offsets handed to SliceIterator together with the slice dimensions.
// Returns Status::OK() after every output element has been written.
// NOTE(review): the diff view interleaves two revisions of the output_end and input_iterator
// lines below: one derives the extent from the local output_shape / output_dims, the other
// from the shape reported by the allocated output tensor.
template <typename T>
Status SliceImpl(OpKernelContext* ctx,
const Tensor& input_tensor,
std::vector<int64_t>& output_dims,
const std::vector<int64_t>& starts) {
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);
auto* output = output_tensor.template MutableData<T>();
const auto* output_end = output + output_shape.Size();
const auto* output_end = output + output_tensor.Shape().Size();
SliceIterator<T> input_iterator(input_tensor, starts, output_dims);
SliceIterator<T> input_iterator(input_tensor, starts, output_tensor.Shape().GetDims());
// Element-by-element copy: advance the iterator once for every output slot.
while (output != output_end)
*output++ = *input_iterator++;
return Status::OK();
}
// Kernel entry point: works out the slice window for input 0 and hands the element copy
// to SliceImpl<T>.
//   dynamic == false -> starts/ends/axes come from the attr_* members.
//   dynamic == true  -> starts/ends/axes are read from the context via FillVectorsFromInput.
// Any non-OK status from PrepareForCompute is returned to the caller unchanged.
template <typename T, bool dynamic>
Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
  const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
  ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");

  const auto& data = *input_tensor_ptr;
  const auto& in_dims = data.Shape().GetDims();

  // Defaults describe the identity slice: offset 0 on every axis, full input extent.
  std::vector<int64_t> begin_offsets(in_dims.size(), 0);
  std::vector<int64_t> out_dims(in_dims);

  if (!dynamic) {
    ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
                                          in_dims, begin_offsets, out_dims));
  } else {
    std::vector<int64_t> input_starts;
    std::vector<int64_t> input_ends;
    std::vector<int64_t> input_axes;
    FillVectorsFromInput(ctx, input_starts, input_ends, input_axes);
    ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
                                          in_dims, begin_offsets, out_dims));
  }

  return SliceImpl<T>(ctx, data, out_dims, begin_offsets);
}
} // namespace onnxruntime

View file

@ -9,35 +9,34 @@ namespace onnxruntime {
// Shared state and helpers for the Slice / DynamicSlice kernels.
// NOTE(review): the diff view interleaves two revisions of this class. They differ in
// whitespace, in the "axes" error message, in whether PrepareForCompute takes a separate
// dimension_count parameter, and in whether FillVectorsFromInput is a template on Tind.
class SliceBase {
protected:
// For the non-dynamic case, reads the "starts", "ends" and "axes" attributes into the attr_*
// members. Enforces that starts and ends are both present and equally sized, and that axes,
// when present, has the same size as starts. When dynamic is true no attributes are read.
SliceBase (const OpKernelInfo& info, bool dynamic = false) {
SliceBase(const OpKernelInfo& info, bool dynamic = false) {
if (!dynamic) {
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
"Missing or invalid starts and ends attribute");
"Missing or invalid starts and ends attribute");
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
"Invalid axes attribute");
"Invalid axes attribute, axes attribute (if present) should have the same size as starts/ends attributes");
}
}
// Takes the raw starts/ends/axes values plus the input dimensions and writes the results into
// the `starts` and `output_dims` out-parameters; returns a Status. Defined out of line.
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const size_t dimension_count,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const;
// FillVectorsFromInput: fills raw_starts / raw_ends / raw_axes from the kernel context.
// Defined out of line.
template<typename Tind>
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const;
void FillVectorsFromInput(const OpKernelContext* context,
std::vector<int64_t>& raw_starts,
std::vector<int64_t>& raw_ends,
std::vector<int64_t>& raw_axes) const;
std::vector<int64_t>& raw_starts,
std::vector<int64_t>& raw_ends,
std::vector<int64_t>& raw_axes) const;
// Attribute values captured by the constructor; left empty when constructed with dynamic == true.
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
};
template <typename T, typename Tind, bool dynamic>
template <typename T, bool dynamic>
struct Slice final : public OpKernel, public SliceBase {
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
Status Compute(OpKernelContext* context) const override;

View file

@ -40,13 +40,13 @@ Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes;
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
FillVectorsFromInput(ctx, input_starts, input_ends, input_axes);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
dimension_count, input_dimensions, starts, output_dims));
input_dimensions, starts, output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions, starts, output_dims));
input_dimensions, starts, output_dims));
}
TensorShape output_shape(output_dims);