mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Rashuai/dynamic slice refactored (#264)
* define dynamic slice * remove obsolete * add test cases * remove disabled cases * rename test cases * fix comments * format code * fix comments * fix compile err * fix typo * removed duplicated delaration * add enforced checks * add enforced checks * add extra processing on negative axis * fix typo
This commit is contained in:
parent
bd2ace7619
commit
fc76076e29
7 changed files with 319 additions and 63 deletions
|
|
@ -171,6 +171,35 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice);
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split);
|
||||
|
|
@ -355,6 +384,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>());
|
||||
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice)>());
|
||||
|
|
@ -368,6 +398,35 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
|
||||
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice)>());
|
||||
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice)>());
|
||||
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split)>());
|
||||
|
|
|
|||
|
|
@ -8,27 +8,64 @@ using namespace std;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
#define ADD_TYPED_SLICE_OP(data_type) \
|
||||
#define ADD_TYPED_SLICE_OP(data_type, indice_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
1, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
|
||||
Slice<data_type>);
|
||||
Slice<data_type, indice_type, false>);
|
||||
|
||||
ADD_TYPED_SLICE_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_OP(int8_t);
|
||||
ADD_TYPED_SLICE_OP(int16_t);
|
||||
ADD_TYPED_SLICE_OP(int32_t);
|
||||
ADD_TYPED_SLICE_OP(int64_t);
|
||||
ADD_TYPED_SLICE_OP(float);
|
||||
ADD_TYPED_SLICE_OP(double);
|
||||
ADD_TYPED_SLICE_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_OP(bool);
|
||||
ADD_TYPED_SLICE_OP(string);
|
||||
ADD_TYPED_SLICE_OP(uint8_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint16_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint32_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(uint64_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int8_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int16_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int32_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(int64_t, int64_t);
|
||||
ADD_TYPED_SLICE_OP(float, int64_t);
|
||||
ADD_TYPED_SLICE_OP(double, int64_t);
|
||||
ADD_TYPED_SLICE_OP(MLFloat16,int64_t);
|
||||
ADD_TYPED_SLICE_OP(bool, int64_t);
|
||||
ADD_TYPED_SLICE_OP(string, int64_t);
|
||||
|
||||
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
DynamicSlice, \
|
||||
1, \
|
||||
data_type##_##indice_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<indice_type>()), \
|
||||
Slice<data_type, indice_type, true>);
|
||||
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t);
|
||||
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t);
|
||||
ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t);
|
||||
|
||||
namespace {
|
||||
// std::clamp doesn't exist until C++17 so create a local version
|
||||
|
|
@ -39,33 +76,33 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
|
|||
return v;
|
||||
}
|
||||
} // namespace
|
||||
Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts, std::vector<int64_t>& output_dims) const {
|
||||
|
||||
Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_axes,
|
||||
const size_t dimension_count,
|
||||
const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const {
|
||||
// Initialize axes to the provided axes attribute or to the default sequence
|
||||
std::vector<int64_t> axes(axes_);
|
||||
if (!has_axes_) {
|
||||
std::vector<int64_t> axes(raw_axes);
|
||||
if (axes.size() == 0) {
|
||||
//axes are omitted, they are set to[0, ..., ndim - 1]
|
||||
axes.resize(starts.size());
|
||||
for (size_t i = 0; i < starts.size(); i++)
|
||||
axes[i] = i;
|
||||
|
||||
if (axes.size() > starts_.size())
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'starts' attribute holds");
|
||||
if (axes.size() > ends_.size())
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'ends' attribute holds");
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
}
|
||||
|
||||
// Iterate through the provided axes and override the start/end ranges
|
||||
for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) {
|
||||
auto axis = static_cast<size_t>(axes[axesIndex]);
|
||||
if (axis >= dimension_count)
|
||||
auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast<int64_t>(dimension_count) : axes[axesIndex];
|
||||
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an axis outside of the tensor dimension count");
|
||||
auto start = starts_[axesIndex];
|
||||
auto start = raw_starts[axesIndex];
|
||||
if (start < 0)
|
||||
start += input_dimensions[axis];
|
||||
starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]);
|
||||
|
||||
auto end = ends_[axesIndex];
|
||||
auto end = raw_ends[axesIndex];
|
||||
if (end < 0)
|
||||
end += input_dimensions[axis];
|
||||
output_dims[axis] = clamp(end, int64_t{0}, input_dimensions[axis]) - starts[axis];
|
||||
|
|
@ -76,8 +113,39 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Slice<T>::Compute(OpKernelContext* ctx) const {
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
void Slice<T, Tind, dynamic>::FillVectors(const OpKernelContext* context,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes) const {
|
||||
ORT_ENFORCE(context->Input<Tensor>(1) != nullptr, "Required starts input is missing");
|
||||
ORT_ENFORCE(context->Input<Tensor>(2) != nullptr, "Required ends input is missing");
|
||||
|
||||
auto starts_tensor_ptr = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array");
|
||||
input_starts = std::vector<int64_t> (starts_tensor_ptr->Data<Tind>(),
|
||||
starts_tensor_ptr->Data<Tind>() +
|
||||
starts_tensor_ptr->Shape().Size());
|
||||
|
||||
auto ends_tensor_ptr = context->Input<Tensor>(2);
|
||||
ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array");
|
||||
input_ends = std::vector<int64_t> (ends_tensor_ptr->Data<Tind>(),
|
||||
ends_tensor_ptr->Data<Tind>() +
|
||||
ends_tensor_ptr->Shape().Size());
|
||||
|
||||
ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input");
|
||||
|
||||
if (context->Input<Tensor>(3) != nullptr) {
|
||||
auto axes_tensor_ptr = context->Input<Tensor>(3);
|
||||
input_axes = std::vector<int64_t> (axes_tensor_ptr->Data<Tind>(),
|
||||
axes_tensor_ptr->Data<Tind>() +
|
||||
axes_tensor_ptr->Shape().Size());
|
||||
ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr);
|
||||
auto& input_tensor = *input_tensor_ptr;
|
||||
|
|
@ -88,7 +156,15 @@ Status Slice<T>::Compute(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> starts(dimension_count, 0);
|
||||
std::vector<int64_t> output_dims(input_dimensions);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims));
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectors(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
|
|
|
|||
|
|
@ -9,33 +9,38 @@ namespace onnxruntime {
|
|||
|
||||
class SliceBase {
|
||||
protected:
|
||||
SliceBase(const OpKernelInfo& info) {
|
||||
has_axes_ = info.GetAttrs("axes", axes_).IsOK();
|
||||
|
||||
ORT_ENFORCE(info.GetAttrs("starts", starts_).IsOK(), "Invalid 'starts' attribute value");
|
||||
ORT_ENFORCE(info.GetAttrs("ends", ends_).IsOK(), "Invalid 'ends' attribute value");
|
||||
|
||||
if (has_axes_) {
|
||||
if (axes_.size() > starts_.size())
|
||||
ORT_THROW("'axes' has more entries than the 'starts' attribute holds");
|
||||
if (axes_.size() > ends_.size())
|
||||
ORT_THROW("'axes' has more entries than the 'ends' attribute holds");
|
||||
SliceBase (const OpKernelInfo& info, bool dynamic = false) {
|
||||
if (!dynamic) {
|
||||
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
|
||||
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
|
||||
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
|
||||
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
|
||||
"Missing or invalid starts and ends attribute");
|
||||
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
|
||||
"Invalid axes attribute");
|
||||
}
|
||||
}
|
||||
|
||||
Status PrepareForCompute(const size_t dimension_count, const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts, std::vector<int64_t>& output_dims) const;
|
||||
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
||||
const std::vector<int64_t>& raw_ends,
|
||||
const std::vector<int64_t>& raw_axes,
|
||||
const size_t dimension_count,
|
||||
const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const;
|
||||
|
||||
std::vector<int64_t> axes_;
|
||||
bool has_axes_;
|
||||
std::vector<int64_t> starts_, ends_;
|
||||
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
struct Slice final : public OpKernel, public SliceBase {
|
||||
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info) {}
|
||||
|
||||
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
private:
|
||||
void FillVectors(const OpKernelContext* context,
|
||||
std::vector<int64_t>& raw_starts,
|
||||
std::vector<int64_t>& raw_ends,
|
||||
std::vector<int64_t>& raw_axes) const;
|
||||
}; // namespace onnxruntime
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,7 +26,9 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> starts(dimension_count, 0);
|
||||
std::vector<int64_t> output_dims(input_dimensions);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims));
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions,
|
||||
starts, output_dims));
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
|
|
|
|||
|
|
@ -270,11 +270,6 @@ int real_main(int argc, char* argv[]) {
|
|||
{"Softsign", "disable reason"},
|
||||
{"convtranspose_1d", "disable reason"},
|
||||
{"convtranspose_3d", "disable reason"},
|
||||
{"dynamic_slice", "disable reason"},
|
||||
{"dynamic_slice_default_axes", "disable reason"},
|
||||
{"dynamic_slice_end_out_of_bounds", "disable reason"},
|
||||
{"dynamic_slice_neg", "disable reason"},
|
||||
{"dynamic_slice_start_out_of_bounds", "disable reason"},
|
||||
{"eyelike_populate_off_main_diagonal", "disable reason"},
|
||||
{"eyelike_with_dtype", "disable reason"},
|
||||
{"eyelike_without_dtype", "disable reason"},
|
||||
|
|
|
|||
124
onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc
Normal file
124
onnxruntime/test/providers/cpu/tensor/dynamic_slice_op_test.cc
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(DynamicSliceTest, dynamic_slice_varied_types) {
|
||||
OpTester test1 ("DynamicSlice", 1);
|
||||
test1.AddInput <int32_t> ("data", {3,3}, {1,2,3,4,5,6,7,8,9});
|
||||
test1.AddInput <int32_t> ("starts", {2}, {1,1});
|
||||
test1.AddInput <int32_t> ("ends", {2}, {3,3});
|
||||
test1.AddOutput <int32_t> ("output", {2,2}, {5,6,8,9});
|
||||
test1.Run();
|
||||
|
||||
OpTester test2("DynamicSlice", 1);
|
||||
test2.AddInput <int64_t> ("data", {3,3}, {1LL,2LL,3LL,4LL,5LL,6LL,7LL,8LL,9LL});
|
||||
test2.AddInput <int32_t> ("starts", {2}, {1,1});
|
||||
test2.AddInput <int32_t> ("ends", {2}, {3,3});
|
||||
test2.AddOutput <int64_t> ("output", {2,2}, {5LL,6LL,8LL,9LL});
|
||||
test2.Run();
|
||||
|
||||
OpTester test3("DynamicSlice", 1);
|
||||
test3.AddInput <std::string> ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"});
|
||||
test3.AddInput <int64_t> ("starts", {2}, {1,1});
|
||||
test3.AddInput <int64_t> ("ends", {2}, {3,3});
|
||||
test3.AddOutput <std::string> ("output", {2,2}, {"e","f","h","i"});
|
||||
test3.Run();
|
||||
|
||||
OpTester test4("DynamicSlice", 1);
|
||||
test4.AddInput <float> ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f});
|
||||
test4.AddInput <int32_t> ("starts", {2}, {1,1});
|
||||
test4.AddInput <int32_t> ("ends", {2}, {3,3});
|
||||
test4.AddOutput <float> ("output", {2,2}, {5.5f,6.6f,8.8f,9.9f});
|
||||
test4.Run();
|
||||
|
||||
OpTester test5("DynamicSlice", 1);
|
||||
test5.AddInput <bool> ("data", {3,3}, {false,true,false,false,false,false,true,false,true});
|
||||
test5.AddInput <int32_t> ("starts", {2}, {1,1});
|
||||
test5.AddInput <int32_t> ("ends", {2}, {3,3});
|
||||
test5.AddOutput <bool> ("output", {2,2}, {false,false,false,true});
|
||||
test5.Run();
|
||||
}
|
||||
|
||||
TEST(DynamicSliceTest, dynamic_slice_with_axes) {
|
||||
OpTester test1 ("DynamicSlice", 1);
|
||||
test1.AddInput <int32_t> ("data", {3,3}, {1,2,3,4,5,6,7,8,9});
|
||||
test1.AddInput <int32_t> ("starts", {1}, {1});
|
||||
test1.AddInput <int32_t> ("ends", {1}, {3});
|
||||
test1.AddInput <int32_t> ("axes", {1}, {-1});
|
||||
test1.AddOutput <int32_t> ("output", {3,2}, {2,3,5,6,8,9});
|
||||
test1.Run();
|
||||
|
||||
OpTester test2 ("DynamicSlice", 1);
|
||||
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test2.AddInput <int32_t> ("starts", {1}, {1});
|
||||
test2.AddInput <int32_t> ("ends", {1}, {2});
|
||||
test2.AddInput <int32_t> ("axes", {1}, {2});
|
||||
test2.AddOutput <int32_t> ("output", {3,3,1}, {2,5,8,11,14,17,20,23,26});
|
||||
test2.Run();
|
||||
}
|
||||
|
||||
TEST(DynamicSliceTest, dynamic_slice_with_negative_axes) {
|
||||
OpTester test1 ("DynamicSlice", 1);
|
||||
test1.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test1.AddInput <int32_t> ("starts", {1}, {1});
|
||||
test1.AddInput <int32_t> ("ends", {1}, {-1});
|
||||
test1.AddInput <int32_t> ("axes", {1}, {1});
|
||||
test1.AddOutput <int32_t> ("output", {3,1,3}, {4,5,6,13,14,15,22,23,24});
|
||||
test1.Run();
|
||||
|
||||
OpTester test2 ("DynamicSlice", 1);
|
||||
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test2.AddInput <int32_t> ("starts", {2}, {-3,0});
|
||||
test2.AddInput <int32_t> ("ends", {2}, {-1,2});
|
||||
test2.AddInput <int32_t> ("axes", {2}, {0,2});
|
||||
test2.AddOutput <int32_t> ("output", {2,3,2}, {1,2,4,5,7,8,10,11,13,14,16,17});
|
||||
test2.Run();
|
||||
}
|
||||
|
||||
TEST(DynamicSliceTest, dynamic_slice_ends_out_of_bounds) {
|
||||
OpTester test ("DynamicSlice", 1);
|
||||
test.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test.AddInput <int32_t> ("starts", {2}, {0,-2});
|
||||
test.AddInput <int32_t> ("ends", {2}, {2,1000});
|
||||
test.AddInput <int32_t> ("axes", {2}, {1,2});
|
||||
test.AddOutput <int32_t> ("output", {3,2,2}, {2,3,5,6,11,12,14,15,20,21,23,24});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(DynamicSliceTest, dynamic_slice_full_axes) {
|
||||
OpTester test1 ("DynamicSlice", 1);
|
||||
test1.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test1.AddInput <int32_t> ("starts", {3}, {0,1,1});
|
||||
test1.AddInput <int32_t> ("ends", {3}, {1,3,2});
|
||||
test1.AddInput <int32_t> ("axes", {3}, {0,1,2});
|
||||
test1.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
|
||||
test1.Run();
|
||||
|
||||
OpTester test2 ("DynamicSlice", 1);
|
||||
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10,11,12,13,14,15,16,17,18,
|
||||
19,20,21,22,23,24,25,26,27});
|
||||
test2.AddInput <int32_t> ("starts", {3}, {1,0,1});
|
||||
test2.AddInput <int32_t> ("ends", {3}, {2,1,3});
|
||||
test2.AddInput <int32_t> ("axes", {3}, {2,0,1});
|
||||
test2.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
|
||||
test2.Run();
|
||||
}
|
||||
|
||||
} // namespace Test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -25,11 +25,6 @@ backend_test.exclude(r'(test_acosh_cpu.*'
|
|||
'|test_convtranspose_1d_cpu.*'
|
||||
'|test_convtranspose_3d_cpu.*'
|
||||
'|test_cosh_example_cpu.*'
|
||||
'|test_dynamic_slice_cpu.*'
|
||||
'|test_dynamic_slice_default_axes_cpu.*'
|
||||
'|test_dynamic_slice_end_out_of_bounds_cpu.*'
|
||||
'|test_dynamic_slice_neg_cpu.*'
|
||||
'|test_dynamic_slice_start_out_of_bounds_cpu.*'
|
||||
'|test_eyelike_populate_off_main_diagonal_cpu.*'
|
||||
'|test_eyelike_with_dtype_cpu.*'
|
||||
'|test_eyelike_without_dtype_cpu.*'
|
||||
|
|
|
|||
Loading…
Reference in a new issue