Rashuai/dynamic slice refactored (#264)

* define dynamic slice

* remove obsolete

* add test cases

* remove disabled cases

* rename test cases

* fix comments

* format code

* fix comments

* fix compile err

* fix typo

* removed duplicated delaration

* add enforced checks

* add enforced checks

* add extra processing on negative axis

* fix typo
This commit is contained in:
Randy 2019-01-02 16:39:41 -08:00 committed by GitHub
parent bd2ace7619
commit fc76076e29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 319 additions and 63 deletions

View file

@ -171,6 +171,35 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split);
@ -355,6 +384,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice)>());
@ -368,6 +398,35 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int32_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string_int64_t, DynamicSlice)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 2, Split)>());

View file

@ -8,27 +8,64 @@ using namespace std;
namespace onnxruntime {
#define ADD_TYPED_SLICE_OP(data_type) \
#define ADD_TYPED_SLICE_OP(data_type, indice_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Slice, \
1, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Slice<data_type>);
Slice<data_type, indice_type, false>);
ADD_TYPED_SLICE_OP(uint8_t);
ADD_TYPED_SLICE_OP(uint16_t);
ADD_TYPED_SLICE_OP(uint32_t);
ADD_TYPED_SLICE_OP(uint64_t);
ADD_TYPED_SLICE_OP(int8_t);
ADD_TYPED_SLICE_OP(int16_t);
ADD_TYPED_SLICE_OP(int32_t);
ADD_TYPED_SLICE_OP(int64_t);
ADD_TYPED_SLICE_OP(float);
ADD_TYPED_SLICE_OP(double);
ADD_TYPED_SLICE_OP(MLFloat16);
ADD_TYPED_SLICE_OP(bool);
ADD_TYPED_SLICE_OP(string);
ADD_TYPED_SLICE_OP(uint8_t, int64_t);
ADD_TYPED_SLICE_OP(uint16_t, int64_t);
ADD_TYPED_SLICE_OP(uint32_t, int64_t);
ADD_TYPED_SLICE_OP(uint64_t, int64_t);
ADD_TYPED_SLICE_OP(int8_t, int64_t);
ADD_TYPED_SLICE_OP(int16_t, int64_t);
ADD_TYPED_SLICE_OP(int32_t, int64_t);
ADD_TYPED_SLICE_OP(int64_t, int64_t);
ADD_TYPED_SLICE_OP(float, int64_t);
ADD_TYPED_SLICE_OP(double, int64_t);
ADD_TYPED_SLICE_OP(MLFloat16,int64_t);
ADD_TYPED_SLICE_OP(bool, int64_t);
ADD_TYPED_SLICE_OP(string, int64_t);
#define ADD_TYPED_DYNAMIC_SLICE_OP(data_type, indice_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
DynamicSlice, \
1, \
data_type##_##indice_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<indice_type>()), \
Slice<data_type, indice_type, true>);
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(double, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(string, int32_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint8_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint16_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint32_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(uint64_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int8_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int16_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int32_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(int64_t, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(float, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(double, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(MLFloat16,int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(bool, int64_t);
ADD_TYPED_DYNAMIC_SLICE_OP(string, int64_t);
namespace {
// std::clamp doesn't exist until C++17 so create a local version
@ -39,33 +76,33 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
return v;
}
} // namespace
Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts, std::vector<int64_t>& output_dims) const {
Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const size_t dimension_count,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const {
// Initialize axes to the provided axes attribute or to the default sequence
std::vector<int64_t> axes(axes_);
if (!has_axes_) {
std::vector<int64_t> axes(raw_axes);
if (axes.size() == 0) {
//axes are omitted, they are set to[0, ..., ndim - 1]
axes.resize(starts.size());
for (size_t i = 0; i < starts.size(); i++)
axes[i] = i;
if (axes.size() > starts_.size())
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'starts' attribute holds");
if (axes.size() > ends_.size())
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has more entries than the 'ends' attribute holds");
std::iota(axes.begin(), axes.end(), 0);
}
// Iterate through the provided axes and override the start/end ranges
for (size_t axesIndex = 0; axesIndex < axes.size(); axesIndex++) {
auto axis = static_cast<size_t>(axes[axesIndex]);
if (axis >= dimension_count)
auto axis = axes[axesIndex] < 0 ? axes[axesIndex] + static_cast<int64_t>(dimension_count) : axes[axesIndex];
if (axis >= static_cast<int64_t>(dimension_count) || axis < 0)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an axis outside of the tensor dimension count");
auto start = starts_[axesIndex];
auto start = raw_starts[axesIndex];
if (start < 0)
start += input_dimensions[axis];
starts[axis] = clamp(start, int64_t{0}, input_dimensions[axis]);
auto end = ends_[axesIndex];
auto end = raw_ends[axesIndex];
if (end < 0)
end += input_dimensions[axis];
output_dims[axis] = clamp(end, int64_t{0}, input_dimensions[axis]) - starts[axis];
@ -76,8 +113,39 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec
return Status::OK();
}
template <typename T>
Status Slice<T>::Compute(OpKernelContext* ctx) const {
template <typename T, typename Tind, bool dynamic>
void Slice<T, Tind, dynamic>::FillVectors(const OpKernelContext* context,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes) const {
ORT_ENFORCE(context->Input<Tensor>(1) != nullptr, "Required starts input is missing");
ORT_ENFORCE(context->Input<Tensor>(2) != nullptr, "Required ends input is missing");
auto starts_tensor_ptr = context->Input<Tensor>(1);
ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array");
input_starts = std::vector<int64_t> (starts_tensor_ptr->Data<Tind>(),
starts_tensor_ptr->Data<Tind>() +
starts_tensor_ptr->Shape().Size());
auto ends_tensor_ptr = context->Input<Tensor>(2);
ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array");
input_ends = std::vector<int64_t> (ends_tensor_ptr->Data<Tind>(),
ends_tensor_ptr->Data<Tind>() +
ends_tensor_ptr->Shape().Size());
ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input");
if (context->Input<Tensor>(3) != nullptr) {
auto axes_tensor_ptr = context->Input<Tensor>(3);
input_axes = std::vector<int64_t> (axes_tensor_ptr->Data<Tind>(),
axes_tensor_ptr->Data<Tind>() +
axes_tensor_ptr->Shape().Size());
ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid");
}
}
template <typename T, typename Tind, bool dynamic>
Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
ORT_ENFORCE(input_tensor_ptr != nullptr);
auto& input_tensor = *input_tensor_ptr;
@ -88,7 +156,15 @@ Status Slice<T>::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> starts(dimension_count, 0);
std::vector<int64_t> output_dims(input_dimensions);
ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims));
if (dynamic) {
std::vector<int64_t> input_starts, input_ends, input_axes;
FillVectors(ctx, input_starts, input_ends, input_axes);
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
dimension_count, input_dimensions, starts, output_dims));
} else {
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions, starts, output_dims));
}
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

View file

@ -9,33 +9,38 @@ namespace onnxruntime {
class SliceBase {
protected:
SliceBase(const OpKernelInfo& info) {
has_axes_ = info.GetAttrs("axes", axes_).IsOK();
ORT_ENFORCE(info.GetAttrs("starts", starts_).IsOK(), "Invalid 'starts' attribute value");
ORT_ENFORCE(info.GetAttrs("ends", ends_).IsOK(), "Invalid 'ends' attribute value");
if (has_axes_) {
if (axes_.size() > starts_.size())
ORT_THROW("'axes' has more entries than the 'starts' attribute holds");
if (axes_.size() > ends_.size())
ORT_THROW("'axes' has more entries than the 'ends' attribute holds");
SliceBase (const OpKernelInfo& info, bool dynamic = false) {
if (!dynamic) {
auto has_starts = info.GetAttrs("starts", attr_starts_).IsOK();
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
"Missing or invalid starts and ends attribute");
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
"Invalid axes attribute");
}
}
Status PrepareForCompute(const size_t dimension_count, const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts, std::vector<int64_t>& output_dims) const;
Status PrepareForCompute(const std::vector<int64_t>& raw_starts,
const std::vector<int64_t>& raw_ends,
const std::vector<int64_t>& raw_axes,
const size_t dimension_count,
const std::vector<int64_t>& input_dimensions,
std::vector<int64_t>& starts,
std::vector<int64_t>& output_dims) const;
std::vector<int64_t> axes_;
bool has_axes_;
std::vector<int64_t> starts_, ends_;
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
};
template <typename T>
template <typename T, typename Tind, bool dynamic>
struct Slice final : public OpKernel, public SliceBase {
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info) {}
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
Status Compute(OpKernelContext* context) const override;
private:
void FillVectors(const OpKernelContext* context,
std::vector<int64_t>& raw_starts,
std::vector<int64_t>& raw_ends,
std::vector<int64_t>& raw_axes) const;
}; // namespace onnxruntime
} // namespace onnxruntime

View file

@ -26,7 +26,9 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const {
std::vector<int64_t> starts(dimension_count, 0);
std::vector<int64_t> output_dims(input_dimensions);
ORT_RETURN_IF_ERROR(PrepareForCompute(dimension_count, input_dimensions, starts, output_dims));
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
dimension_count, input_dimensions,
starts, output_dims));
TensorShape output_shape(output_dims);
auto output_tensor = ctx->Output(0, output_shape);

View file

@ -270,11 +270,6 @@ int real_main(int argc, char* argv[]) {
{"Softsign", "disable reason"},
{"convtranspose_1d", "disable reason"},
{"convtranspose_3d", "disable reason"},
{"dynamic_slice", "disable reason"},
{"dynamic_slice_default_axes", "disable reason"},
{"dynamic_slice_end_out_of_bounds", "disable reason"},
{"dynamic_slice_neg", "disable reason"},
{"dynamic_slice_start_out_of_bounds", "disable reason"},
{"eyelike_populate_off_main_diagonal", "disable reason"},
{"eyelike_with_dtype", "disable reason"},
{"eyelike_without_dtype", "disable reason"},

View file

@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
TEST(DynamicSliceTest, dynamic_slice_varied_types) {
OpTester test1 ("DynamicSlice", 1);
test1.AddInput <int32_t> ("data", {3,3}, {1,2,3,4,5,6,7,8,9});
test1.AddInput <int32_t> ("starts", {2}, {1,1});
test1.AddInput <int32_t> ("ends", {2}, {3,3});
test1.AddOutput <int32_t> ("output", {2,2}, {5,6,8,9});
test1.Run();
OpTester test2("DynamicSlice", 1);
test2.AddInput <int64_t> ("data", {3,3}, {1LL,2LL,3LL,4LL,5LL,6LL,7LL,8LL,9LL});
test2.AddInput <int32_t> ("starts", {2}, {1,1});
test2.AddInput <int32_t> ("ends", {2}, {3,3});
test2.AddOutput <int64_t> ("output", {2,2}, {5LL,6LL,8LL,9LL});
test2.Run();
OpTester test3("DynamicSlice", 1);
test3.AddInput <std::string> ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"});
test3.AddInput <int64_t> ("starts", {2}, {1,1});
test3.AddInput <int64_t> ("ends", {2}, {3,3});
test3.AddOutput <std::string> ("output", {2,2}, {"e","f","h","i"});
test3.Run();
OpTester test4("DynamicSlice", 1);
test4.AddInput <float> ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f});
test4.AddInput <int32_t> ("starts", {2}, {1,1});
test4.AddInput <int32_t> ("ends", {2}, {3,3});
test4.AddOutput <float> ("output", {2,2}, {5.5f,6.6f,8.8f,9.9f});
test4.Run();
OpTester test5("DynamicSlice", 1);
test5.AddInput <bool> ("data", {3,3}, {false,true,false,false,false,false,true,false,true});
test5.AddInput <int32_t> ("starts", {2}, {1,1});
test5.AddInput <int32_t> ("ends", {2}, {3,3});
test5.AddOutput <bool> ("output", {2,2}, {false,false,false,true});
test5.Run();
}
TEST(DynamicSliceTest, dynamic_slice_with_axes) {
OpTester test1 ("DynamicSlice", 1);
test1.AddInput <int32_t> ("data", {3,3}, {1,2,3,4,5,6,7,8,9});
test1.AddInput <int32_t> ("starts", {1}, {1});
test1.AddInput <int32_t> ("ends", {1}, {3});
test1.AddInput <int32_t> ("axes", {1}, {-1});
test1.AddOutput <int32_t> ("output", {3,2}, {2,3,5,6,8,9});
test1.Run();
OpTester test2 ("DynamicSlice", 1);
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test2.AddInput <int32_t> ("starts", {1}, {1});
test2.AddInput <int32_t> ("ends", {1}, {2});
test2.AddInput <int32_t> ("axes", {1}, {2});
test2.AddOutput <int32_t> ("output", {3,3,1}, {2,5,8,11,14,17,20,23,26});
test2.Run();
}
TEST(DynamicSliceTest, dynamic_slice_with_negative_axes) {
OpTester test1 ("DynamicSlice", 1);
test1.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test1.AddInput <int32_t> ("starts", {1}, {1});
test1.AddInput <int32_t> ("ends", {1}, {-1});
test1.AddInput <int32_t> ("axes", {1}, {1});
test1.AddOutput <int32_t> ("output", {3,1,3}, {4,5,6,13,14,15,22,23,24});
test1.Run();
OpTester test2 ("DynamicSlice", 1);
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test2.AddInput <int32_t> ("starts", {2}, {-3,0});
test2.AddInput <int32_t> ("ends", {2}, {-1,2});
test2.AddInput <int32_t> ("axes", {2}, {0,2});
test2.AddOutput <int32_t> ("output", {2,3,2}, {1,2,4,5,7,8,10,11,13,14,16,17});
test2.Run();
}
TEST(DynamicSliceTest, dynamic_slice_ends_out_of_bounds) {
OpTester test ("DynamicSlice", 1);
test.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test.AddInput <int32_t> ("starts", {2}, {0,-2});
test.AddInput <int32_t> ("ends", {2}, {2,1000});
test.AddInput <int32_t> ("axes", {2}, {1,2});
test.AddOutput <int32_t> ("output", {3,2,2}, {2,3,5,6,11,12,14,15,20,21,23,24});
test.Run();
}
TEST(DynamicSliceTest, dynamic_slice_full_axes) {
OpTester test1 ("DynamicSlice", 1);
test1.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test1.AddInput <int32_t> ("starts", {3}, {0,1,1});
test1.AddInput <int32_t> ("ends", {3}, {1,3,2});
test1.AddInput <int32_t> ("axes", {3}, {0,1,2});
test1.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
test1.Run();
OpTester test2 ("DynamicSlice", 1);
test2.AddInput <int32_t> ("data", {3,3,3}, {1, 2, 3, 4, 5, 6, 7, 8, 9,
10,11,12,13,14,15,16,17,18,
19,20,21,22,23,24,25,26,27});
test2.AddInput <int32_t> ("starts", {3}, {1,0,1});
test2.AddInput <int32_t> ("ends", {3}, {2,1,3});
test2.AddInput <int32_t> ("axes", {3}, {2,0,1});
test2.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
test2.Run();
}
} // namespace Test
} // namespace onnxruntime

View file

@ -25,11 +25,6 @@ backend_test.exclude(r'(test_acosh_cpu.*'
'|test_convtranspose_1d_cpu.*'
'|test_convtranspose_3d_cpu.*'
'|test_cosh_example_cpu.*'
'|test_dynamic_slice_cpu.*'
'|test_dynamic_slice_default_axes_cpu.*'
'|test_dynamic_slice_end_out_of_bounds_cpu.*'
'|test_dynamic_slice_neg_cpu.*'
'|test_dynamic_slice_start_out_of_bounds_cpu.*'
'|test_eyelike_populate_off_main_diagonal_cpu.*'
'|test_eyelike_with_dtype_cpu.*'
'|test_eyelike_without_dtype_cpu.*'