mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support uint8 datatype for Upsample op in CPU and CUDA providers (#440)
This commit is contained in:
parent
2062c49033
commit
d35409f58e
6 changed files with 44 additions and 0 deletions
|
|
@ -209,6 +209,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Tra
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Unsqueeze);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand);
|
||||
|
|
@ -453,6 +454,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Unsqueeze)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand)>());
|
||||
|
|
|
|||
|
|
@ -22,6 +22,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
Upsample<int32_t>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Upsample,
|
||||
7, 9,
|
||||
uint8_t,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<uint8_t>()),
|
||||
Upsample<uint8_t>);
|
||||
|
||||
template <typename T>
|
||||
void UpsampleNearest2x(
|
||||
int64_t batch_size,
|
||||
|
|
|
|||
|
|
@ -501,6 +501,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, double, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>());
|
||||
|
|
@ -764,6 +765,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, double, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample)>());
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRegistry> GetCudaKernelRegistry() {
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ REGISTER_KERNEL_TYPED(float)
|
|||
REGISTER_KERNEL_TYPED(double)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(int32_t)
|
||||
REGISTER_KERNEL_TYPED(uint8_t)
|
||||
|
||||
template <typename T>
|
||||
Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<float>& scales) const {
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ SPECIALIZED_IMPL(float)
|
|||
SPECIALIZED_IMPL(double)
|
||||
SPECIALIZED_IMPL(half)
|
||||
SPECIALIZED_IMPL(int32_t)
|
||||
SPECIALIZED_IMPL(uint8_t)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -70,6 +70,37 @@ TEST(UpsampleOpTest, UpsampleOpNearestTest_int32) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(UpsampleOpTest, UpsampleOpNearestTest_uint8) {
|
||||
OpTester test("Upsample");
|
||||
|
||||
std::vector<float> scales{1.0f, 1.0f, 2.0f, 3.0f};
|
||||
test.AddAttribute("mode", "nearest");
|
||||
test.AddAttribute("scales", scales);
|
||||
|
||||
const int64_t N = 1, C = 2, H = 2, W = 2;
|
||||
std::vector<uint8_t> X = {1, 3,
|
||||
3, 5,
|
||||
|
||||
3, 5,
|
||||
7, 9};
|
||||
|
||||
test.AddInput<uint8_t>("X", {N, C, H, W}, X);
|
||||
|
||||
std::vector<uint8_t> Y = {
|
||||
1, 1, 1, 3, 3, 3,
|
||||
1, 1, 1, 3, 3, 3,
|
||||
3, 3, 3, 5, 5, 5,
|
||||
3, 3, 3, 5, 5, 5,
|
||||
|
||||
3, 3, 3, 5, 5, 5,
|
||||
3, 3, 3, 5, 5, 5,
|
||||
7, 7, 7, 9, 9, 9,
|
||||
7, 7, 7, 9, 9, 9};
|
||||
|
||||
test.AddOutput<uint8_t>("Y", {N, C, (int64_t)(H * scales[2]), (int64_t)(W * scales[3])}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(UpsampleOpTest, UpsampleOpNearest2XTest) {
|
||||
OpTester test("Upsample");
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue