From d35409f58ef881e5b7e02533fa9c199e98386dff Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 5 Feb 2019 15:08:52 -0800 Subject: [PATCH] Support uint8 datatype for Upsample op in CPU and CUDA providers (#440) --- .../providers/cpu/cpu_execution_provider.cc | 2 ++ .../core/providers/cpu/tensor/upsample.cc | 7 +++++ .../providers/cuda/cuda_execution_provider.cc | 2 ++ .../core/providers/cuda/tensor/upsample.cc | 1 + .../providers/cuda/tensor/upsample_impl.cu | 1 + .../providers/cpu/tensor/upsample_op_test.cc | 31 +++++++++++++++++++ 6 files changed, 44 insertions(+) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 81cc7fd72d..b8fc7ce270 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -209,6 +209,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Tra class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Unsqueeze); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand); @@ -453,6 +454,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index f6fa2195f8..8b6d7dfabe 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -22,6 +22,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Upsample); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + Upsample, + 7, 9, + uint8_t, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Upsample); + template void UpsampleNearest2x( int64_t batch_size, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index db8a5c88d3..fc86ebc1f9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -501,6 +501,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, double, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample); static void RegisterCudaKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); @@ -764,6 +765,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); } std::shared_ptr GetCudaKernelRegistry() { diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 8347077852..2f4e3a37aa 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -27,6 +27,7 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(int32_t) +REGISTER_KERNEL_TYPED(uint8_t) template Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& scales) const { diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu index baf1ac8c0b..6c723efff8 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu @@ -127,6 +127,7 @@ SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) SPECIALIZED_IMPL(half) SPECIALIZED_IMPL(int32_t) +SPECIALIZED_IMPL(uint8_t) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc index f129c5c3f7..d954c49af6 100644 --- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc @@ -70,6 +70,37 @@ TEST(UpsampleOpTest, UpsampleOpNearestTest_int32) { test.Run(); } +TEST(UpsampleOpTest, UpsampleOpNearestTest_uint8) { + OpTester test("Upsample"); + + std::vector scales{1.0f, 1.0f, 2.0f, 3.0f}; + test.AddAttribute("mode", "nearest"); + test.AddAttribute("scales", scales); + + const int64_t N = 1, C = 2, H = 2, W = 2; + std::vector X = {1, 3, + 3, 5, + + 3, 5, + 7, 9}; + + test.AddInput("X", {N, C, H, W}, X); + + std::vector Y = { + 1, 1, 1, 3, 3, 3, + 1, 1, 1, 3, 3, 3, + 3, 3, 3, 5, 5, 5, + 3, 3, 3, 5, 5, 5, + + 3, 3, 3, 5, 5, 5, + 3, 3, 3, 5, 5, 5, + 7, 7, 7, 9, 9, 9, + 7, 7, 7, 9, 9, 9}; + + test.AddOutput("Y", {N, C, (int64_t)(H * scales[2]), (int64_t)(W * scales[3])}, Y); + test.Run(); +} + TEST(UpsampleOpTest, UpsampleOpNearest2XTest) { OpTester test("Upsample");