mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Support Round op in the CUDA EP (#2601)
* Support Round op for the CUDA EP
* Update version
* Fix build
* Fix opset version
* Update
* PR comments
* Fix build
* Nit
This commit is contained in:
parent
715e365723
commit
b3d0b114fe
6 changed files with 44 additions and 11 deletions
|
|
@ -119,6 +119,24 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
|
|||
// half specialization of _Erf: casts the argument to float, evaluates erff on
// it, and converts the result back to half.
template <>
|
||||
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
|
||||
|
||||
// Primary template of the device-side rounding helper. Only the declaration is
// given here; the bodies come from the explicit specializations (float, double
// and half are the ones visible in this file).
template <typename T>
|
||||
__device__ __inline__ T _Round(T a);
|
||||
|
||||
// float specialization of _Round: forwards directly to rintf.
// NOTE(review): the CUDA Math API documents rintf as rounding halfway cases to
// the nearest even integer — confirm this is the behaviour the Round op needs.
template <>
|
||||
__device__ __inline__ float _Round(float a) { return rintf(a); }
|
||||
|
||||
// double specialization of _Round: forwards directly to rint.
// NOTE(review): the CUDA Math API documents rint as rounding halfway cases to
// the nearest even integer — confirm this is the behaviour the Round op needs.
template <>
|
||||
__device__ __inline__ double _Round(double a) { return rint(a); }
|
||||
|
||||
// half specialization of _Round. The implementation is selected at compile
// time from __CUDA_ARCH__:
//   * below 530 (an undefined __CUDA_ARCH__ evaluates as 0 inside #if, so this
//     branch is also taken when the macro is not defined): cast to float,
//     round with rintf, convert back to half;
//   * otherwise: call the half intrinsic hrint directly.
// NOTE(review): presumably the 530 threshold reflects the compute capability
// hrint requires (SM 5.3+) — confirm against the cuda_fp16 documentation.
template <>
|
||||
__device__ __inline__ half _Round(half a) {
|
||||
#if __CUDA_ARCH__ < 530
|
||||
// Fallback path: do the rounding in float.
return half(rintf((float)a));
|
||||
#else
|
||||
// Half-precision path.
return hrint(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Primary template declaration of the device-side _Exp helper. No definition
// or specialization of it is visible in this excerpt.
template <typename T>
|
||||
__device__ __inline__ T _Exp(T a);
|
||||
|
||||
|
|
|
|||
|
|
@ -667,6 +667,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
// Forward declarations of kernel classes for kCudaExecutionProvider in
// kOnnxDomain at version 11. The class names are produced by the
// ONNX_OPERATOR_*_KERNEL_CLASS_NAME macros; the same macro expressions are used
// again as template arguments to BuildKernelCreateInfo further down the file.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal);
|
||||
// Round: one typed kernel class per element type (float, double, MLFloat16).
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round);
|
||||
// CumSum uses the untyped class-name macro (no element-type argument).
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, CumSum);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
|
|
@ -1129,6 +1132,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, CumSum)>,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,15 +23,14 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
x<T>);
|
||||
|
||||
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
x, \
|
||||
kOnnxDomain, \
|
||||
ver, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
// Expands to ONNX_OPERATOR_TYPED_KERNEL_EX for a unary logical operator.
//   x   - operator name; also used as the kernel class template, as x<T>
//   ver - version number passed through to the registration macro
//   T   - element type
// The kernel definition constrains "T" to the tensor type of T and "T1" to the
// tensor type of bool, for kCudaExecutionProvider in kOnnxDomain.
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
x, \
|
||||
kOnnxDomain, \
|
||||
ver, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
x<T>);
|
||||
|
||||
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
|
||||
|
|
@ -51,7 +50,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
|
|||
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_COMPUTE(name, T)
|
||||
|
||||
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
|
||||
// Convenience macro for a typed unary logical operator: expands to
// UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) followed by
// UNARY_ELEMENTWISE_COMPUTE(name, T).
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_COMPUTE(name, T)
|
||||
|
||||
|
|
@ -98,6 +97,7 @@ UNARY_OP_HFD(Log, 6)
|
|||
// Per-operator instantiations; each line names the operator and a version number.
// NOTE(review): UNARY_OP_HFD is defined outside this excerpt; presumably "HFD"
// means it covers half, float and double — confirm at the macro definition.
UNARY_OP_HFD(Exp, 6)
|
||||
UNARY_OP_HFD(Erf, 9)
|
||||
// Not is instantiated for bool only, through the logical-op macro.
UNARY_LOGICALOP_TYPED(Not, 1, bool)
|
||||
// Round is instantiated with version 11, matching the version-11 Round kernel
// classes declared for the CUDA execution provider.
UNARY_OP_HFD(Round, 11)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -91,5 +91,12 @@ class Not final : public UnaryElementwise {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
// Kernel class for the Round operator, templated on the element type T.
// Derives from UnaryElementwise and is marked final.
template <typename T>
|
||||
class Round final : public UnaryElementwise {
|
||||
public:
|
||||
// Forwards the kernel info to the UnaryElementwise base; no other state is set up.
Round(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
// Override declared here only; its definition is not part of this class body.
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
|
|||
// Per-operator invocations of the implementation macros.
// NOTE(review): SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD is defined outside this
// excerpt; presumably it instantiates the implementation for half, float and
// double — confirm at the macro definition.
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
|
||||
// Round implementation instantiation.
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
|
||||
// Not is instantiated for bool only, via the single-type macro.
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
|
||||
|
||||
// When casting, half needs to be converted via float type from most other types
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ namespace cuda {
|
|||
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
|
||||
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
|
||||
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
|
||||
UNARY_OP_NAME_EXPR(Not, !a)
|
||||
UNARY_OP_NAME_EXPR(Not, !a) \
|
||||
UNARY_OP_NAME_EXPR(Round, _Round(a))
|
||||
|
||||
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
Loading…
Reference in a new issue