diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index a685396a75..b62a799fef 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -119,6 +119,24 @@ __device__ __inline__ double _Erf(double a) { return erf(a); } template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } +template +__device__ __inline__ T _Round(T a); + +template <> +__device__ __inline__ float _Round(float a) { return rintf(a); } + +template <> +__device__ __inline__ double _Round(double a) { return rint(a); } + +template <> +__device__ __inline__ half _Round(half a) { +#if __CUDA_ARCH__ < 530 + return half(rintf((float)a)); +#else + return hrint(a); +#endif +} + template __device__ __inline__ T _Exp(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6c15a7ffe2..375d714612 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -667,6 +667,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, CumSum); static void RegisterCudaKernels(KernelRegistry& kernel_registry) { @@ -1129,6 +1132,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index af58f32c9c..3fec649929 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -23,15 +23,14 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); -#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - x, \ - kOnnxDomain, \ - ver, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ +#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + x, \ + kOnnxDomain, \ + ver, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ x); #define UNARY_ELEMENTWISE_COMPUTE(x, T) \ @@ -51,7 +50,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \ UNARY_ELEMENTWISE_COMPUTE(name, T) -#define UNARY_LOGICALOP_TYPED(name, ver, T) \ +#define UNARY_LOGICALOP_TYPED(name, ver, T) \ UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \ UNARY_ELEMENTWISE_COMPUTE(name, T) @@ -98,6 +97,7 @@ UNARY_OP_HFD(Log, 6) UNARY_OP_HFD(Exp, 6) UNARY_OP_HFD(Erf, 9) UNARY_LOGICALOP_TYPED(Not, 1, bool) +UNARY_OP_HFD(Round, 11) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 1bd6d2160c..efeb26783b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -91,5 +91,12 @@ class Not final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Round final : public UnaryElementwise { + public: + Round(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index a8ae4315a8..41a93cf4cc 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool) // When casting, half needs to be converted via float type from most other types diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index c22f597507..81123c46bf 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -23,7 +23,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \ UNARY_OP_NAME_EXPR(Log, _Log(a)) \ UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \ - UNARY_OP_NAME_EXPR(Not, !a) + UNARY_OP_NAME_EXPR(Not, !a) \ + UNARY_OP_NAME_EXPR(Round, _Round(a)) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \