Add missing types for Slice op (#74)

This commit is contained in:
Hector Li 2018-12-03 13:54:48 -08:00 committed by GitHub
parent aa549cd194
commit 1c9d0b2729
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 11 deletions

View file

@ -161,7 +161,19 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
@ -333,7 +345,19 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());

View file

@ -4,14 +4,32 @@
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/cpu/tensor/utils.h"
using namespace ::onnxruntime::common;
using namespace std;
namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
Slice,
1,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Slice<float>);
#define ADD_TYPED_SLICE_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Slice, \
1, \
data_type, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Slice<data_type>);
ADD_TYPED_SLICE_OP(uint8_t);
ADD_TYPED_SLICE_OP(uint16_t);
ADD_TYPED_SLICE_OP(uint32_t);
ADD_TYPED_SLICE_OP(uint64_t);
ADD_TYPED_SLICE_OP(int8_t);
ADD_TYPED_SLICE_OP(int16_t);
ADD_TYPED_SLICE_OP(int32_t);
ADD_TYPED_SLICE_OP(int64_t);
ADD_TYPED_SLICE_OP(float);
ADD_TYPED_SLICE_OP(double);
ADD_TYPED_SLICE_OP(MLFloat16);
ADD_TYPED_SLICE_OP(bool);
ADD_TYPED_SLICE_OP(string);
namespace {
// std::clamp doesn't exist until C++17 so create a local version
template <typename T>
@ -58,8 +76,8 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec
return Status::OK();
}
template <>
Status Slice<float>::Compute(OpKernelContext* ctx) const {
template <typename T>
Status Slice<T>::Compute(OpKernelContext* ctx) const {
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
ONNXRUNTIME_ENFORCE(input_tensor_ptr != nullptr);
auto& input_tensor = *input_tensor_ptr;
@ -74,10 +92,10 @@ Status Slice<float>::Compute(OpKernelContext* ctx) const {
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);
auto* output = output_tensor.template MutableData<float>();
auto* output = output_tensor.template MutableData<T>();
const auto* output_end = output + output_shape.Size();
SliceIterator<float> input_iterator(input_tensor, starts, output_dims);
SliceIterator<T> input_iterator(input_tensor, starts, output_dims);
while (output != output_end)
*output++ = *input_iterator++;

View file

@ -134,5 +134,29 @@ TEST(SliceTest, Slice3D) {
test.Run();
}
TEST(SliceTest, Slice1D_Int) {
OpTester test("Slice");
test.AddAttribute("axes", std::vector<int64_t>{0});
test.AddAttribute("starts", std::vector<int64_t>{2});
test.AddAttribute("ends", std::vector<int64_t>{4});
test.AddInput<int32_t>("data", {6}, {0L, 1L, 2L, 3L, 4L, 5L});
test.AddOutput<int32_t>("output", {2}, {2L, 3L});
test.Run();
}
TEST(SliceTest, Slice1D_String) {
OpTester test("Slice");
test.AddAttribute("axes", std::vector<int64_t>{0});
test.AddAttribute("starts", std::vector<int64_t>{2});
test.AddAttribute("ends", std::vector<int64_t>{4});
test.AddInput<std::string>("data", {6}, {"0", "1", "2", "3", "4", "5"});
test.AddOutput<std::string>("output", {2}, {"2", "3"});
test.Run();
}
} // namespace Test
} // namespace onnxruntime