mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Support int64 in ReduceMin cuda op for Opset 14 (#8307)
* reducemin int64_t support
* fix xxcuda.so load error
* testtest
* refactor
* update doc
* propagate types to opset14
* re-generate doc
* rename macro
This commit is contained in:
parent
8d8db7c9f0
commit
04297110c3
3 changed files with 78 additions and 29 deletions
|
|
@ -607,7 +607,8 @@ Do not modify directly.*
|
|||
|ReduceMean|*in* data:**T**<br> *out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|
||||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|
||||
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|
||||
|ReduceMin|*in* data:**T**<br> *out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|ReduceMin|*in* data:**T**<br> *out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|
||||
|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|
||||
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
|
||||
|
|
|
|||
|
|
@ -1058,12 +1058,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd);
|
||||
|
|
@ -1159,6 +1159,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin);
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub);
|
||||
|
|
@ -1885,12 +1893,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd)>,
|
||||
|
|
@ -1985,6 +1993,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin)>,
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
|
||||
|
|
|
|||
|
|
@ -40,8 +40,7 @@ namespace cuda {
|
|||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
// Register those with changes in OpSet12.
|
||||
#define REGISTER_KERNEL_TYPED_12(name, T) \
|
||||
#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
|
|
@ -65,7 +64,11 @@ namespace cuda {
|
|||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>); \
|
||||
name<T>);
|
||||
|
||||
// Register those with changes in OpSet12.
|
||||
#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
|
|
@ -75,6 +78,28 @@ namespace cuda {
|
|||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
// Registers the CUDA kernels for reduction op `name` with element type T up to and
// including OpSet13: first everything REGISTER_KERNEL_VERSIONED_TYPED_12 registers
// (presumably the earlier opset ranges -- its full body is not shown here), then a
// versioned kernel restricted to OpSet13 only ([13, 13]).
// Used for ops (e.g. ReduceMin) that get a separate registration at OpSet14.
#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
13, 13, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
// Registers a non-versioned CUDA kernel for reduction op `name` with element type T
// at OpSet14. The macro itself is generic over `name` and T; it was introduced so
// that ReduceMin can register int64_t (alongside its existing types) in OpSet14.
|
||||
#define REGISTER_KERNEL_TYPED_14(name, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
14, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
// CUDA ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now.
|
||||
#define REGISTER_KERNEL_TYPED_11(name, T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
|
|
@ -956,22 +981,30 @@ REGISTER_KERNEL_HFD_11(ArgMin)
|
|||
REGISTER_KERNEL_HFD(ReduceL1)
|
||||
REGISTER_KERNEL_HFD(ReduceL2)
|
||||
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, MLFloat16)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, float)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, double)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, int32_t)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, int64_t)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, int8_t)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMax, uint8_t)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, MLFloat16)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, float)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, double)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int32_t)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int64_t)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int8_t)
|
||||
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, uint8_t)
|
||||
|
||||
REGISTER_KERNEL_HFD(ReduceMean)
|
||||
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, MLFloat16)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, float)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, double)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, int32_t)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, int8_t)
|
||||
REGISTER_KERNEL_TYPED_12(ReduceMin, uint8_t)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, MLFloat16)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, float)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, double)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int32_t)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int8_t)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, uint8_t)
|
||||
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, MLFloat16)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, float)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, double)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, int32_t)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, int8_t)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, uint8_t)
|
||||
REGISTER_KERNEL_TYPED_14(ReduceMin, int64_t)
|
||||
|
||||
REGISTER_KERNEL_HFD(ReduceProd)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue