From 04297110c3e0eb317b6bfd2e4476da7d3b0fe362 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Tue, 13 Jul 2021 16:18:06 -0700
Subject: [PATCH] Support int64 in ReduceMin cuda op for Opset 14 (#8307)
* reducemin int64_t support
* fix xxcuda.so load error
* testtest
* refactor
* update doc
* propagate types to opset14
* re-generate doc
* rename macro
---
docs/OperatorKernels.md | 3 +-
.../providers/cuda/cuda_execution_provider.cc | 39 +++++++----
.../providers/cuda/reduction/reduction_ops.cc | 65 ++++++++++++++-----
3 files changed, 78 insertions(+), 29 deletions(-)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a844c2b4f1..3665bae302 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -607,7 +607,8 @@ Do not modify directly.*
|ReduceMean|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
-|ReduceMin|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
+|ReduceMin|*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 9e210932f1..900674b932 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1058,12 +1058,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd);
@@ -1159,6 +1159,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin);
+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
@@ -1885,12 +1893,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1985,6 +1993,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
index 338442b887..783e6beb3a 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -40,8 +40,7 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
name);
-// Register those with changes in OpSet12.
-#define REGISTER_KERNEL_TYPED_12(name, T) \
+#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -65,7 +64,11 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- name); \
+ name);
+
+// Register those with changes in OpSet12.
+#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \
+ REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -75,6 +78,28 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
name);
+#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \
+ REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ name, \
+ kOnnxDomain, \
+ 13, 13, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ name);
+
+// Register ReduceMin int64_t support in OpSet14.
+#define REGISTER_KERNEL_TYPED_14(name, T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ name, \
+ kOnnxDomain, \
+ 14, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ name);
+
// CUDA ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now.
#define REGISTER_KERNEL_TYPED_11(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
@@ -956,22 +981,30 @@ REGISTER_KERNEL_HFD_11(ArgMin)
REGISTER_KERNEL_HFD(ReduceL1)
REGISTER_KERNEL_HFD(ReduceL2)
-REGISTER_KERNEL_TYPED_12(ReduceMax, MLFloat16)
-REGISTER_KERNEL_TYPED_12(ReduceMax, float)
-REGISTER_KERNEL_TYPED_12(ReduceMax, double)
-REGISTER_KERNEL_TYPED_12(ReduceMax, int32_t)
-REGISTER_KERNEL_TYPED_12(ReduceMax, int64_t)
-REGISTER_KERNEL_TYPED_12(ReduceMax, int8_t)
-REGISTER_KERNEL_TYPED_12(ReduceMax, uint8_t)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, MLFloat16)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, float)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, double)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int32_t)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int64_t)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int8_t)
+REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, uint8_t)
REGISTER_KERNEL_HFD(ReduceMean)
-REGISTER_KERNEL_TYPED_12(ReduceMin, MLFloat16)
-REGISTER_KERNEL_TYPED_12(ReduceMin, float)
-REGISTER_KERNEL_TYPED_12(ReduceMin, double)
-REGISTER_KERNEL_TYPED_12(ReduceMin, int32_t)
-REGISTER_KERNEL_TYPED_12(ReduceMin, int8_t)
-REGISTER_KERNEL_TYPED_12(ReduceMin, uint8_t)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, MLFloat16)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, float)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, double)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int32_t)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int8_t)
+REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, uint8_t)
+
+REGISTER_KERNEL_TYPED_14(ReduceMin, MLFloat16)
+REGISTER_KERNEL_TYPED_14(ReduceMin, float)
+REGISTER_KERNEL_TYPED_14(ReduceMin, double)
+REGISTER_KERNEL_TYPED_14(ReduceMin, int32_t)
+REGISTER_KERNEL_TYPED_14(ReduceMin, int8_t)
+REGISTER_KERNEL_TYPED_14(ReduceMin, uint8_t)
+REGISTER_KERNEL_TYPED_14(ReduceMin, int64_t)
REGISTER_KERNEL_HFD(ReduceProd)