Support int64 in ReduceMin cuda op for Opset 14 (#8307)

* reducemin int64_t support

* fix xxcuda.so load error

* testtest

* refactor

* update doc

* propagate types to opset14

* re-generate doc

* rename macro
This commit is contained in:
Ye Wang 2021-07-13 16:18:06 -07:00 committed by GitHub
parent 8d8db7c9f0
commit 04297110c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 29 deletions

View file

@ -607,7 +607,8 @@ Do not modify directly.*
|ReduceMean|*in* data:**T**<br> *out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|ReduceMin|*in* data:**T**<br> *out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|ReduceMin|*in* data:**T**<br> *out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|

View file

@ -1058,12 +1058,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd);
@ -1159,6 +1159,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin);
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
@ -1885,12 +1893,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd)>,
@ -1985,6 +1993,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin)>,
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,

View file

@ -40,8 +40,7 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// Register those with changes in OpSet12.
#define REGISTER_KERNEL_TYPED_12(name, T) \
#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@ -65,7 +64,11 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
name<T>);
// Register those with changes in OpSet12.
#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@ -75,6 +78,28 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
13, 13, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// Register ReduceMin int64_t support in OpSet14.
#define REGISTER_KERNEL_TYPED_14(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
14, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// CUDA ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now.
#define REGISTER_KERNEL_TYPED_11(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
@ -956,22 +981,30 @@ REGISTER_KERNEL_HFD_11(ArgMin)
REGISTER_KERNEL_HFD(ReduceL1)
REGISTER_KERNEL_HFD(ReduceL2)
REGISTER_KERNEL_TYPED_12(ReduceMax, MLFloat16)
REGISTER_KERNEL_TYPED_12(ReduceMax, float)
REGISTER_KERNEL_TYPED_12(ReduceMax, double)
REGISTER_KERNEL_TYPED_12(ReduceMax, int32_t)
REGISTER_KERNEL_TYPED_12(ReduceMax, int64_t)
REGISTER_KERNEL_TYPED_12(ReduceMax, int8_t)
REGISTER_KERNEL_TYPED_12(ReduceMax, uint8_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, MLFloat16)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, float)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, double)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int32_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int64_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int8_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, uint8_t)
REGISTER_KERNEL_HFD(ReduceMean)
REGISTER_KERNEL_TYPED_12(ReduceMin, MLFloat16)
REGISTER_KERNEL_TYPED_12(ReduceMin, float)
REGISTER_KERNEL_TYPED_12(ReduceMin, double)
REGISTER_KERNEL_TYPED_12(ReduceMin, int32_t)
REGISTER_KERNEL_TYPED_12(ReduceMin, int8_t)
REGISTER_KERNEL_TYPED_12(ReduceMin, uint8_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, MLFloat16)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, float)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, double)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int32_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int8_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, uint8_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, MLFloat16)
REGISTER_KERNEL_TYPED_14(ReduceMin, float)
REGISTER_KERNEL_TYPED_14(ReduceMin, double)
REGISTER_KERNEL_TYPED_14(ReduceMin, int32_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, int8_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, uint8_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, int64_t)
REGISTER_KERNEL_HFD(ReduceProd)