mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add missing types for Slice op (#74)
This commit is contained in:
parent
aa549cd194
commit
1c9d0b2729
3 changed files with 77 additions and 11 deletions
|
|
@ -161,7 +161,19 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace);
|
||||
|
|
@ -333,7 +345,19 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 5, Reshape)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Shape)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Size)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, bool, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint8_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint16_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint32_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, uint64_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int8_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int16_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, string, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, SpaceToDepth)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 4, DepthToSpace)>());
|
||||
|
|
|
|||
|
|
@ -4,14 +4,32 @@
|
|||
#include "core/providers/cpu/tensor/slice.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Slice,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Slice<float>);
|
||||
#define ADD_TYPED_SLICE_OP(data_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Slice, \
|
||||
1, \
|
||||
data_type, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
|
||||
Slice<data_type>);
|
||||
|
||||
ADD_TYPED_SLICE_OP(uint8_t);
|
||||
ADD_TYPED_SLICE_OP(uint16_t);
|
||||
ADD_TYPED_SLICE_OP(uint32_t);
|
||||
ADD_TYPED_SLICE_OP(uint64_t);
|
||||
ADD_TYPED_SLICE_OP(int8_t);
|
||||
ADD_TYPED_SLICE_OP(int16_t);
|
||||
ADD_TYPED_SLICE_OP(int32_t);
|
||||
ADD_TYPED_SLICE_OP(int64_t);
|
||||
ADD_TYPED_SLICE_OP(float);
|
||||
ADD_TYPED_SLICE_OP(double);
|
||||
ADD_TYPED_SLICE_OP(MLFloat16);
|
||||
ADD_TYPED_SLICE_OP(bool);
|
||||
ADD_TYPED_SLICE_OP(string);
|
||||
|
||||
namespace {
|
||||
// std::clamp doesn't exist until C++17 so create a local version
|
||||
template <typename T>
|
||||
|
|
@ -58,8 +76,8 @@ Status SliceBase::PrepareForCompute(const size_t dimension_count, const std::vec
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Slice<float>::Compute(OpKernelContext* ctx) const {
|
||||
template <typename T>
|
||||
Status Slice<T>::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ONNXRUNTIME_ENFORCE(input_tensor_ptr != nullptr);
|
||||
auto& input_tensor = *input_tensor_ptr;
|
||||
|
|
@ -74,10 +92,10 @@ Status Slice<float>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
auto* output = output_tensor.template MutableData<float>();
|
||||
auto* output = output_tensor.template MutableData<T>();
|
||||
const auto* output_end = output + output_shape.Size();
|
||||
|
||||
SliceIterator<float> input_iterator(input_tensor, starts, output_dims);
|
||||
SliceIterator<T> input_iterator(input_tensor, starts, output_dims);
|
||||
while (output != output_end)
|
||||
*output++ = *input_iterator++;
|
||||
|
||||
|
|
|
|||
|
|
@ -134,5 +134,29 @@ TEST(SliceTest, Slice3D) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SliceTest, Slice1D_Int) {
|
||||
OpTester test("Slice");
|
||||
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0});
|
||||
test.AddAttribute("starts", std::vector<int64_t>{2});
|
||||
test.AddAttribute("ends", std::vector<int64_t>{4});
|
||||
|
||||
test.AddInput<int32_t>("data", {6}, {0L, 1L, 2L, 3L, 4L, 5L});
|
||||
test.AddOutput<int32_t>("output", {2}, {2L, 3L});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SliceTest, Slice1D_String) {
|
||||
OpTester test("Slice");
|
||||
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0});
|
||||
test.AddAttribute("starts", std::vector<int64_t>{2});
|
||||
test.AddAttribute("ends", std::vector<int64_t>{4});
|
||||
|
||||
test.AddInput<std::string>("data", {6}, {"0", "1", "2", "3", "4", "5"});
|
||||
test.AddOutput<std::string>("output", {2}, {"2", "3"});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace Test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue